In [None]:
from scipy.io import loadmat
from utils import *
data = loadmat('emnist-digits.mat')
import matplotlib.pyplot as plt
import numpy as np
from labels_utils import *
from features_utils import *

In [5]:
dataset = data['dataset'][0, 0]
train = dataset[0][0, 0]  
test = dataset[1][0, 0]  
mapping = dataset[2]

train_images = train['images']   # Shape: (N, 28*28)
train_labels = train['labels']  # Shape: (N, 1)
train_writers = train['writers']

X = train_images.astype(np.float32) / 255.0  # Normalize to [0, 1]
y = train_labels.flatten().astype(np.int64)

# Wrap into a datalist with a single client
datalist = [(X, y)]

test_images = test['images'].astype(np.float32) / 255.0
test_labels = test['labels'].flatten().astype(np.int64)

In [58]:
### Defined global Hyperparameters
T = 30      # number of global rounds
K = 5      # number of client GD steps
gamma = 0.05 # learning rate

# Baseline FedAvg


### FedAvg results

With only 1 client holding all the data

In [54]:



# Run FedAvg with 1 client
print("now training the baseline, i.e. fedAvg with one client holding all the data")
model = fedavg(datalist, T, K, gamma)

now training the baseline, i.e. fedAvg with one client holding all the data
round :  1
round :  2
round :  3
round :  4
round :  5
round :  6
round :  7
round :  8
round :  9
round :  10
round :  11
round :  12
round :  13
round :  14
round :  15
round :  16
round :  17
round :  18
round :  19
round :  20
round :  21
round :  22
round :  23
round :  24
round :  25
round :  26
round :  27
round :  28
round :  29
round :  30


In [None]:
test_accuracy = evaluate(model, test_images, test_labels)
print(test_accuracy)

0.8547250032424927


With n_clients

In [59]:
n_clients = 10

In [None]:
datalist = make_femnist_datasets(X,y,train,n_clients)
# Hyperparameters
print(f"case with {n_clients} clients, with feature distribution shift")
model = fedavg(datalist, T, K, gamma)

  by_writer[int(writer)].append(idx)


case with 10 clients, with feature distribution shift
round :  1
round :  2
round :  3
round :  4
round :  5
round :  6
round :  7
round :  8
round :  9
round :  10
round :  11
round :  12
round :  13
round :  14
round :  15
round :  16
round :  17
round :  18
round :  19
round :  20
round :  21
round :  22
round :  23
round :  24
round :  25
round :  26
round :  27
round :  28
round :  29
round :  30


In [61]:
test_accuracy = evaluate(model, test_images, test_labels)
print(test_accuracy)

0.8424999713897705


# Labels Shift

#### With Dirichlet distributions

In [None]:
n_clients = 10
beta = 0.1 
datalist = create_dirichlet_clients(X, y, n_clients, beta)

In [None]:
# Hyperparameters
print("case with 5 clients, beta=0.5 skewed distribution!")
model = fedavg(datalist, T, K, gamma)

case with 5 clients, beta=0.5 skewed distribution!
round :  1
round :  2
round :  3
round :  4
round :  5
round :  6
round :  7
round :  8
round :  9
round :  10
round :  11
round :  12
round :  13
round :  14
round :  15
round :  16
round :  17
round :  18
round :  19
round :  20
round :  21
round :  22
round :  23
round :  24
round :  25
round :  26
round :  27
round :  28
round :  29
round :  30


In [None]:
test_accuracy = evaluate(model, test_images, test_labels)
print(test_accuracy)

0.7575500011444092


### MOON results

In [None]:
n_clients = 10
beta = 0.1
datalist = create_dirichlet_clients(X, y, n_clients, beta)

In [None]:
model, loss_curve = fedavg_moon(datalist, T, K, gamma, mu=0.5)


In [None]:
test_accuracy = evaluate(model, test_images, test_labels)
print(test_accuracy)

0.9042999744415283


# Features Shift

In [None]:
n_clients = 10
datalist = make_femnist_datasets(X,y,train,n_clients)
# Hyperparameters
T = 30       # number of global rounds
K = 5      # number of client GD steps
gamma = 0.05 # learning rate
print(f"case with {n_clients} clients, with feature distribution shift")
model = fedavg(datalist, T, K, gamma)

case with 10 clients, with feature distribution shift
round :  1
round :  2
round :  3
round :  4
round :  5
round :  6
round :  7
round :  8
round :  9
round :  10
round :  11
round :  12
round :  13
round :  14
round :  15
round :  16
round :  17
round :  18
round :  19
round :  20
round :  21
round :  22
round :  23
round :  24
round :  25
round :  26
round :  27
round :  28
round :  29
round :  30


In [47]:
test_accuracy = evaluate(model, test_images, test_labels)
print(f"Test Accuracy with {n_clients} ): {test_accuracy * 100:.2f}%")

Test Accuracy with 10 ): 84.46%


In [50]:
def fedavg_disk(datalist, alphas_list, client_sizes, T, K, gamma):
    """
    Perform FedAvg with data-size weighting and sample-weighted loss.

    Args:
      datalist: list of tuples (X_tensor, y_tensor) per client
      alphas_list: list of alpha tensors (shape same as y_tensor) per client
      client_sizes: list of int N_k for each client (length K)
      T: number of communication rounds
      K: number of local GD steps per client per round
      gamma: learning rate for local updates

    Returns:
      global_model: trained global PyTorch model
    """
    n_clients = len(datalist)
    total_samples = sum(client_sizes)
    # Initialize global model
    global_model = SimpleNN()
    global_state = global_model.state_dict()

    # Precompute weights N_k / N
    weights = [Nk / total_samples for Nk in client_sizes]

    for t in range(1, T + 1):
        local_states = []
        # Broadcast & local training
        for i in range(n_clients):
            client_model = SimpleNN()
            client_model.load_state_dict(deepcopy(global_state))
            X_i, y_i = datalist[i]
            alpha_i = alphas_list[i]
            # Ensure data on same device as model
            X_i = torch.tensor(X_i, dtype=torch.float32)
            y_i = torch.tensor(y_i, dtype=torch.long)
                        # 3c) Clip & renormalize α to avoid extremely large weights
            if not isinstance(alpha_i, torch.Tensor):
                alpha_i = torch.tensor(alpha_i, dtype=torch.float32)
            alpha_i = torch.clamp(alpha_i, max=10.0)         # clip step
            alpha_i = alpha_i * (len(alpha_i) / alpha_i.sum())         # now sum(alpha_i)== N_k
            # Perform K local steps
            updated_model = client_update(client_model, X_i, y_i, alpha_i, K, gamma)
            local_states.append(deepcopy(updated_model.state_dict()))

        # Aggregate weighted by client_sizes
        new_global_state = deepcopy(global_state)
        for key in global_state.keys():
            # Weighted sum of parameters
            new_global_state[key] = sum(weights[i] * local_states[i][key] for i in range(n_clients))
        global_state = new_global_state
        global_model.load_state_dict(global_state)

    return global_model

In [51]:
# 1) Create MADE data loaders (unchanged)
made_loaders = [
    DataLoader(
        TensorDataset(
            torch.tensor(X, dtype=torch.float32),
            torch.zeros(len(X), dtype=torch.float32)
        ),
        batch_size=64,
        shuffle=True
    )
    for X, _ in datalist
]

# 2) Train global MADE (unchanged)
global_made = train_global_made(
    made_loaders,
    dim=28*28,
    hid=100,
    rounds=T,
    local_epochs=1
)

# 3) Compute sample-weights α for each client (unchanged)
sample_weights = []
for ld in made_loaders:
    local_made = MADE(28*28, 100)
    local_state = train_local_made(local_made, ld, epochs=1)
    local_made.load_state_dict(local_state)

    alpha = compute_sample_weights(global_made, local_made, ld)
    # alpha is a FloatTensor of shape [n_samples] for this client
    sample_weights.append(alpha)

# ─────────────────────────────────────────────────────────────────────────────
# 4) Instead of oversampling, just record:
#      - alphas_list (one α-tensor per client)
#      - client_sizes (one integer per client)
#    and leave datalist as-is.

alphas_list = sample_weights                  # list of length n_clients, each α has shape (N_k,)
client_sizes = [ len(y) for (_, y) in datalist ]  # e.g. [N_1, N_2, ..., N_K]

# 5) Call the new fedavg signature directly on (X,y), alphas_list, and client_sizes
print(f"case with {len(datalist)} clients, with feature distribution shift")
model = fedavg_disk(
    datalist,       # unchanged: [(X1, y1), (X2, y2), …]
    alphas_list,    # per-sample weights from Phase 1
    client_sizes,   # [len(y1), len(y2), …]
    T, K, gamma
)

case with 10 clients, with feature distribution shift


In [52]:
# Evaluate
test_accuracy = evaluate(model, test_images, test_labels)
print(f"Test Accuracy with {n_clients} ): {test_accuracy * 100:.2f}%")

Test Accuracy with 10 ): 89.84%
