In [1]:
import torch
import torch.nn as nn
import numpy as np

from splitters import get_imbalanced_client_loaders
from states import ClientState, AdamClientState
from compresors import compress_model, IdenticalCompressor, ExpDitheringCompressor, RandKCompressor
from validation_utils import total_evaluation
from algorithms.feen import feen_training_step
from algorithms.fedavg import fed_avg_training_step
from models import ResNet18

In [2]:
# --- Experiment configuration ------------------------------------------------
SEED = 42          # fixed seed so repeated runs start from the same RNG state
num_clients = 10   # number of federated clients; one model is built per client
gamma = 1e-3       # passed to ClientState as its third argument
lmbd = 0.1         # passed to ClientState as its second argument

# Seed the torch and numpy global RNGs before any model is constructed.
# NOTE(review): data-loader shuffling or cuDNN kernels may still introduce
# run-to-run variation; this only pins the RNGs visible from this cell.
torch.manual_seed(SEED)
np.random.seed(SEED)

# Keep using GPU #2 when the machine has at least three GPUs (the original
# behaviour); otherwise fall back to the first GPU instead of failing with an
# invalid device ordinal, and to the CPU when CUDA is unavailable.
if torch.cuda.is_available():
    gpu_index = 2 if torch.cuda.device_count() > 2 else 0
    device = torch.device(f"cuda:{gpu_index}")
else:
    device = torch.device("cpu")

# One model, one Adam optimizer and one ClientState per client.
# NOTE(review): ResNet18 receives `num_clients` as its only argument. If that
# parameter is the number of output classes, it matches only because both
# happen to be 10 here; confirm against models.ResNet18.
models = [ResNet18(num_clients).to(device) for _ in range(num_clients)]
client_optimizers = [torch.optim.Adam(model.parameters(), lr=1e-5) for model in models]
clients_states = [ClientState(model, lmbd, gamma) for model in models]
criterion = nn.CrossEntropyLoss()


In [3]:
# Build the data loaders for `num_clients` clients.
# NOTE(review): presumably the first two results are per-client collections and
# the last two are shared across clients; verify against
# splitters.get_imbalanced_client_loaders.
(
    local_train_loader,
    local_test_loader,
    common_loader,
    validation_loader,
) = get_imbalanced_client_loaders(num_clients)

In [5]:
# Run 100 FEEN training steps. Metrics are reported right after the step at
# indices 0, 10, ..., 90 (ten reports in total), so the state after the last
# nine steps is never evaluated.
for step in range(100):
    feen_training_step(
        models, clients_states, client_optimizers,
        local_train_loader, common_loader, criterion, device,
    )

    # Only every tenth step triggers an evaluation.
    if step % 10 != 0:
        continue

    metrics = total_evaluation(models, local_test_loader, validation_loader, device)
    single_acc = metrics['single_accuracies']
    local_acc = metrics['local_accuracies']
    # Report the ensemble score plus mean +- std of the two accuracy lists.
    print(
        f"Ensemble accuracy {metrics['ensemble_accuracy']:.3f}"
        f" || Single accuracy = {np.mean(single_acc):.3f} +- {np.std(single_acc):.3f}"
        f" || Local accuracy = {np.mean(local_acc):.3f} +- {np.std(local_acc):.3f}"
    )


Ensemble accuracy 0.216 || Single accuracy = 0.195 +- 0.025 || Local accuracy = 0.276 +- 0.030
Ensemble accuracy 0.299 || Single accuracy = 0.271 +- 0.009 || Local accuracy = 0.377 +- 0.026
Ensemble accuracy 0.338 || Single accuracy = 0.310 +- 0.016 || Local accuracy = 0.438 +- 0.030
Ensemble accuracy 0.366 || Single accuracy = 0.330 +- 0.017 || Local accuracy = 0.462 +- 0.029
Ensemble accuracy 0.396 || Single accuracy = 0.342 +- 0.017 || Local accuracy = 0.477 +- 0.025
Ensemble accuracy 0.407 || Single accuracy = 0.347 +- 0.017 || Local accuracy = 0.490 +- 0.029
Ensemble accuracy 0.431 || Single accuracy = 0.350 +- 0.017 || Local accuracy = 0.489 +- 0.029
Ensemble accuracy 0.439 || Single accuracy = 0.349 +- 0.018 || Local accuracy = 0.490 +- 0.030
Ensemble accuracy 0.444 || Single accuracy = 0.348 +- 0.020 || Local accuracy = 0.489 +- 0.034
Ensemble accuracy 0.450 || Single accuracy = 0.347 +- 0.018 || Local accuracy = 0.489 +- 0.033
