# Imports

In [1]:
from data_utils import * 
from model_utils import *
from utils import FLClient, FLServer
import torch
import numpy as np
import pandas as pd
import matplotlib.pyplot as plt
from itertools import product
import os
import json
import seaborn as sns
import matplotlib.pyplot as plt
import matplotlib
import tensorflow as tf
from tensorflow.keras.utils import to_categorical

  from .autonotebook import tqdm as notebook_tqdm


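# Functions

This section defines no helpers. `partition_data`, `FLServer` and the other utilities come from the local `data_utils`, `model_utils` and `utils` modules imported above.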

# Load CIFAR-10 (private) and CIFAR-100 (public) data for FL

In [2]:
# Private (federated) dataset and the public dataset used for knowledge transfer.
# The public dataset is the other CIFAR variant: CIFAR-100 when the private task is
# CIFAR-10, and CIFAR-10 when it is CIFAR-100.
dataset = 'cifar10'
num_classes = 10 if dataset == 'cifar10' else 100
pub_num_classes = 100 if num_classes == 10 else 10
public_dataset = 'cifar100' if dataset == 'cifar10' else 'cifar10'
datadir = '../data'
partition = 'iid'
n_parties = 5
beta = 0.5  # forwarded to partition_data; presumably only relevant for non-iid splits -- TODO confirm

# Private data, split across n_parties clients (uses the configured dataset, not a hardcoded name).
(X_train, y_train, X_test, y_test, net_dataidx_map) = partition_data(dataset, datadir=datadir, partition=partition, n_parties=n_parties, beta=beta)
# Public data, split iid into 10 shards; only shard 0 is used as the shared public set below.
(X_train_public, y_train_public, X_test_public, y_test_public, net_dataidx_map_public) = partition_data(public_dataset, datadir=datadir, partition='iid', n_parties=10, beta=0.5)

# When the public dataset has more classes than the private one (CIFAR-100 -> CIFAR-10),
# collapse public labels into num_classes buckets by integer division
# (0-9 -> 0, 10-19 -> 1, ...) so they live in the private label space.
# NOTE: the buckets are contiguous label-index ranges, not semantically matched classes.
if pub_num_classes > num_classes:
    label_factor = pub_num_classes // num_classes
    y_train_public = y_train_public // label_factor
    y_test_public = y_test_public // label_factor

y_train_cat = to_categorical(y_train, num_classes=num_classes)
y_test_cat = to_categorical(y_test, num_classes=num_classes)

# Public labels are one-hot encoded with num_classes (not pub_num_classes) so their
# width matches the private task.
y_train_public_cat = to_categorical(y_train_public, num_classes=num_classes)
y_test_public_cat = to_categorical(y_test_public, num_classes=num_classes)

print(X_train.shape, y_train.shape, X_test.shape, y_test.shape, y_train_cat.shape, y_test_cat.shape)
print(X_train_public.shape, y_train_public.shape, X_test_public.shape, y_test_public.shape, y_train_public_cat.shape, y_test_public_cat.shape)

Files already downloaded and verified
Files already downloaded and verified
Files already downloaded and verified
Files already downloaded and verified
(50000, 32, 32, 3) (50000,) (10000, 32, 32, 3) (10000,) (50000, 10) (10000, 10)
(50000, 32, 32, 3) (50000,) (10000, 32, 32, 3) (10000,) (50000, 10) (10000, 10)


In [3]:
# Shared public set: shard 0 of the public partition.
public_idx = net_dataidx_map_public[0]
public_set = (X_train_public[public_idx], y_train_public_cat[public_idx])

# One private training shard per client; every client is evaluated on the full test set.
local_sets = [
    (X_train[net_dataidx_map[k]], y_train_cat[net_dataidx_map[k]])
    for k in range(n_parties)
]
test_sets = [(X_test, y_test_cat) for _ in range(n_parties)]

print(public_set[0].shape, public_set[1].shape)
for k, ((x_loc, y_loc), (x_tst, y_tst)) in enumerate(zip(local_sets, test_sets)):
    print('client ', k, ' ', x_loc.shape, y_loc.shape)
    print(x_tst.shape, y_tst.shape)
    print()


(5000, 32, 32, 3) (5000, 10)
client  0   (10000, 32, 32, 3) (10000, 10)
(10000, 32, 32, 3) (10000, 10)

client  1   (10000, 32, 32, 3) (10000, 10)
(10000, 32, 32, 3) (10000, 10)

client  2   (10000, 32, 32, 3) (10000, 10)
(10000, 32, 32, 3) (10000, 10)

client  3   (10000, 32, 32, 3) (10000, 10)
(10000, 32, 32, 3) (10000, 10)

client  4   (10000, 32, 32, 3) (10000, 10)
(10000, 32, 32, 3) (10000, 10)



In [4]:
# Exploratory run, smaller than the benchmark cells below: 3 clients, 20 rounds,
# model-weight aggregation.
aggregation_method = 'weights'  # one of 'weights', 'grads', 'soft_labels', 'compressed_soft_labels'
aug = False
weighting = 'uniform'           # 'uniform' or 'performance_based'
private = False
hyperparameter_tuning = False
C = 1
initial_pub_alignment_epochs = 4

fl_params = dict(
    client_num=3,
    tot_T=20,                   # number of communication rounds
    C=C,
    local_sets=local_sets,
    test_sets=test_sets,
    public_set=public_set,
    batch_size=32,
    epochs=3,
    lr=0.003,
    aggregate=aggregation_method,
    hyperparameter_tuning=hyperparameter_tuning,
    weighting=weighting,
    default_client_id=1,
    augment=aug,
    private=private,
    # DP settings; presumably only used when private=True -- TODO confirm in FLServer
    max_grad_norm=1.0,
    delta=1e-4,
    epsilon=5,
    local_benchmark_epochs=70,
    initial_pub_alignment_epochs=initial_pub_alignment_epochs,
)

# Results directory keyed by dataset, DP flag, public-set size and the run tag.
# NOTE: client_num, lr and tot_T are not part of the path, so runs differing only
# in those settings share a directory.
N_pub = len(fl_params['public_set'][0])
run_tag = (
    f"Agg{fl_params['aggregate']}_C{fl_params['C']}_HT{fl_params['hyperparameter_tuning']}"
    f"_Aug{fl_params['augment']}_W{fl_params['weighting']}"
)
exp_path = f"../fl_results/{dataset}/DP{private}/N_pub{N_pub}/{run_tag}"
fl_params['exp_path'] = exp_path
server = FLServer(fl_params)

In [5]:


# Run every configured round on `server`, printing per-round accuracy summaries.
for round_idx in range(fl_params['tot_T']):
    round_stats = server.global_update(verbose=True)
    avg_acc, min_acc, max_acc, avg_train_acc = round_stats
    print(f"Round {round_idx} : avg_acc {avg_acc}, min_acc {min_acc}, max_acc {max_acc}, avg_train_acc {avg_train_acc}")




# server = FLServer(fl_params)
# for i in range(4) : 
#     accs, losses = train(server.clients[0].model, server.clients[0].local_dl, server.clients[0].optimizer)
#     print(f"Round {i} : accs {accs}, losses {losses}")


Epoch 0 | Train acc 0.11441693290734824 | Test acc 0.12809504792332269
Epoch 1 | Train acc 0.1422723642172524 | Test acc 0.15115814696485624
Epoch 2 | Train acc 0.15495207667731628 | Test acc 0.15485223642172524
Epoch 0 | Train acc 0.10283546325878594 | Test acc 0.10583067092651757
Epoch 1 | Train acc 0.1239017571884984 | Test acc 0.15215654952076677
Epoch 2 | Train acc 0.15744808306709265 | Test acc 0.16813099041533547
Epoch 0 | Train acc 0.10453274760383387 | Test acc 0.11281948881789138
Epoch 1 | Train acc 0.1326876996805112 | Test acc 0.15505191693290735
Epoch 2 | Train acc 0.16094249201277955 | Test acc 0.15894568690095848
Round 0 : avg_acc 0.16064297124600638, min_acc 0.15485223642172524, max_acc 0.16813099041533547, avg_train_acc 0.15778088391906284
Epoch 0 | Train acc 0.10083865814696485 | Test acc 0.09994009584664537
Epoch 1 | Train acc 0.10083865814696485 | Test acc 0.10003993610223642
Epoch 2 | Train acc 0.10363418530351437 | Test acc 0.10023961661341853
Epoch 0 | Train acc 

KeyboardInterrupt: 

In [10]:
# Exploratory run with compressed soft-label aggregation: 3 clients, 20 rounds.
aggregation_method = 'compressed_soft_labels'  # one of 'weights', 'grads', 'soft_labels', 'compressed_soft_labels'
aug = False
weighting = 'uniform'                          # 'uniform' or 'performance_based'
private = False
hyperparameter_tuning = False
C = 1
initial_pub_alignment_epochs = 4

fl_params = dict(
    client_num=3,
    tot_T=20,                  # number of communication rounds
    C=C,
    local_sets=local_sets,
    test_sets=test_sets,
    public_set=public_set,
    batch_size=32,
    epochs=3,
    lr=0.003,
    aggregate=aggregation_method,
    hyperparameter_tuning=hyperparameter_tuning,
    weighting=weighting,
    default_client_id=1,
    augment=aug,
    private=private,
    # DP settings; presumably only used when private=True -- TODO confirm in FLServer
    max_grad_norm=1.0,
    delta=1e-4,
    epsilon=5,
    local_benchmark_epochs=70,
    initial_pub_alignment_epochs=initial_pub_alignment_epochs,
    temperature=0.6,           # presumably the soft-label softmax temperature -- TODO confirm
)

# Results directory keyed by dataset, DP flag, public-set size and the run tag.
# NOTE: temperature, client_num, lr and tot_T are not part of the path.
N_pub = len(fl_params['public_set'][0])
run_tag = (
    f"Agg{fl_params['aggregate']}_C{fl_params['C']}_HT{fl_params['hyperparameter_tuning']}"
    f"_Aug{fl_params['augment']}_W{fl_params['weighting']}"
)
exp_path = f"../fl_results/{dataset}/DP{private}/N_pub{N_pub}/{run_tag}"
fl_params['exp_path'] = exp_path
soft_labels_server = FLServer(fl_params)


In [None]:
# Compare upload payload sizes (in KB) for client 0: compressed soft labels,
# raw soft labels, and the full model state_dict.
# NOTE(review): the raw soft labels and the model come from `server` (whichever server
# was last assigned to that name), while the compressed soft labels come from
# `soft_labels_server`. If compression is value-dependent, confirm this cross-server
# comparison is intended.
from utils import size_of

raw_client = server.clients[0]
soft_labels = raw_client.get_soft_labels()
compressed_soft_labels = soft_labels_server.clients[0].get_soft_labels(normalize=True, compress=True)
clientmodel = raw_client.model
ssize, cssize = (size_of(payload, size='KB') for payload in (soft_labels, compressed_soft_labels))
clientmodel_state = clientmodel.state_dict()
msize = size_of(clientmodel_state, size='KB')

cssize, ssize, msize


In [None]:
# Train the compressed-soft-label server for 20 rounds, printing per-round summaries.
n_rounds = 20
for round_idx in range(n_rounds):
    round_stats = soft_labels_server.global_update(verbose=True)
    avg_acc, min_acc, max_acc, avg_train_acc = round_stats
    print(f"Round {round_idx} : avg_acc {avg_acc}, min_acc {min_acc}, max_acc {max_acc}, avg_train_acc {avg_train_acc}")


Epoch 0 | Train acc 0.09644568690095846 | Test acc 0.09574680511182108
Epoch 1 | Train acc 0.10812699680511183 | Test acc 0.11112220447284345
Epoch 2 | Train acc 0.11950878594249201 | Test acc 0.12190495207667731
Epoch 0 | Train acc 0.130491214057508 | Test acc 0.16383785942492013
Epoch 1 | Train acc 0.1796126198083067 | Test acc 0.1876996805111821
Epoch 2 | Train acc 0.19349041533546327 | Test acc 0.19418929712460065
Epoch 0 | Train acc 0.10013977635782748 | Test acc 0.11671325878594249
Epoch 1 | Train acc 0.1387779552715655 | Test acc 0.15575079872204473
Epoch 2 | Train acc 0.1706269968051118 | Test acc 0.191194089456869
Round 0 : avg_acc 0.16909611288604898, min_acc 0.12190495207667731, max_acc 0.19418929712460065, avg_train_acc 0.16120873269435568
Epoch 0 | Train acc 0.19818290734824281 | Test acc 0.20746805111821087
Epoch 1 | Train acc 0.20666932907348243 | Test acc 0.20766773162939298
Epoch 2 | Train acc 0.2097643769968051 | Test acc 0.21475638977635783
Epoch 0 | Train acc 0.1917

KeyboardInterrupt: 

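## FedAvg

Baseline: clients exchange model weights (`aggregate='weights'`) with no augmentation. The run is followed by a local-only benchmark for each client.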

In [None]:
# FedAvg benchmark: weight aggregation over all clients for 50 rounds, then a
# local-only benchmark per client for comparison, then save results.
aggregation_method = 'weights'  # one of 'weights', 'grads', 'soft_labels', 'compressed_soft_labels'
aug = False
weighting = 'uniform'           # 'uniform' or 'performance_based'
private = False
hyperparameter_tuning = False
C = 1

fl_params = dict(
    client_num=len(local_sets),
    tot_T=50,                   # number of communication rounds
    C=C,
    local_sets=local_sets,
    test_sets=test_sets,
    public_set=public_set,
    batch_size=32,
    epochs=3,
    lr=0.005,
    aggregate=aggregation_method,
    hyperparameter_tuning=hyperparameter_tuning,
    weighting=weighting,
    default_client_id=1,
    augment=aug,
    private=private,
    # DP settings; presumably only used when private=True -- TODO confirm in FLServer
    max_grad_norm=1.0,
    delta=1e-4,
    epsilon=5,
    local_benchmark_epochs=70,
)

N_pub = len(fl_params['public_set'][0])
run_tag = (
    f"Agg{fl_params['aggregate']}_C{fl_params['C']}_HT{fl_params['hyperparameter_tuning']}"
    f"_Aug{fl_params['augment']}_W{fl_params['weighting']}"
)
exp_path = f"../fl_results/{dataset}/DP{private}/N_pub{N_pub}/{run_tag}"
fl_params['exp_path'] = exp_path
server = FLServer(fl_params)

# Federated training: record the average test accuracy of every round.
FL_acc = []
for rnd in range(fl_params['tot_T']):
    avg_acc, min_acc, max_acc, avg_train_acc = server.global_update()
    FL_acc.append(avg_acc)
    print(f"Round {rnd} accuracy: avg: {avg_acc} min: {min_acc}  max: {max_acc} avg_train: {avg_train_acc}")
print("Final accuracy: ", FL_acc[-1])
print()

# Local-only baseline for each client.
client_accs = []
for idx, fl_client in enumerate(server.clients):
    bench_acc = fl_client.local_benchmark()
    client_accs.append(bench_acc)
    print(f"Client {idx} local benchmark accuracy: {bench_acc}")
print("Client accuracies: ", client_accs, "  ", np.mean(client_accs))
print()

server.save_assets()


## FedMD

Clients exchange soft labels (`aggregate='soft_labels'`) with no augmentation.

In [None]:
# FedMD benchmark: soft-label aggregation over all clients for 50 rounds, no augmentation.
aggregation_method = 'soft_labels'  # one of 'weights', 'grads', 'soft_labels', 'compressed_soft_labels'
aug = False
weighting = 'uniform'               # 'uniform' or 'performance_based'
private = False
hyperparameter_tuning = False
C = 1

fl_params = dict(
    client_num=len(local_sets),
    tot_T=50,                       # number of communication rounds
    C=C,
    local_sets=local_sets,
    test_sets=test_sets,
    public_set=public_set,
    batch_size=32,
    epochs=3,
    lr=0.005,
    aggregate=aggregation_method,
    hyperparameter_tuning=hyperparameter_tuning,
    weighting=weighting,
    default_client_id=1,
    augment=aug,
    private=private,
    # DP settings; presumably only used when private=True -- TODO confirm in FLServer
    max_grad_norm=1.0,
    delta=1e-4,
    epsilon=5,
    local_benchmark_epochs=70,
)

N_pub = len(fl_params['public_set'][0])
run_tag = (
    f"Agg{fl_params['aggregate']}_C{fl_params['C']}_HT{fl_params['hyperparameter_tuning']}"
    f"_Aug{fl_params['augment']}_W{fl_params['weighting']}"
)
exp_path = f"../fl_results/{dataset}/DP{private}/N_pub{N_pub}/{run_tag}"
fl_params['exp_path'] = exp_path
server = FLServer(fl_params)

# Federated training: record the average test accuracy of every round.
FL_acc = []
for rnd in range(fl_params['tot_T']):
    avg_acc, min_acc, max_acc, avg_train_acc = server.global_update()
    FL_acc.append(avg_acc)
    print(f"Round {rnd} accuracy: avg: {avg_acc} min: {min_acc}  max: {max_acc} avg_train: {avg_train_acc}")
print("Final accuracy: ", FL_acc[-1])
print()

# Local-only benchmark, currently disabled for this run:
# client_accs = []
# for c, client in enumerate(server.clients):
#     acc = client.local_benchmark()
#     client_accs.append(acc)
#     print(f"Client {c} local benchmark accuracy: {acc}")
# print("Client accuracies: ", client_accs, "  ", np.mean(client_accs))
# print()

server.save_assets()


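## CFedAKD

Clients exchange compressed soft labels (`aggregate='compressed_soft_labels'`) with augmentation enabled.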

In [None]:
# CFedAKD benchmark: compressed soft-label aggregation over all clients for 50 rounds,
# with augmentation enabled.
aggregation_method = 'compressed_soft_labels'  # one of 'weights', 'grads', 'soft_labels', 'compressed_soft_labels'
aug = True
weighting = 'uniform'                          # 'uniform' or 'performance_based'
private = False
hyperparameter_tuning = False
C = 1

fl_params = dict(
    client_num=len(local_sets),
    tot_T=50,                                  # number of communication rounds
    C=C,
    local_sets=local_sets,
    test_sets=test_sets,
    public_set=public_set,
    batch_size=32,
    epochs=3,
    lr=0.005,
    aggregate=aggregation_method,
    hyperparameter_tuning=hyperparameter_tuning,
    weighting=weighting,
    default_client_id=1,
    augment=aug,
    private=private,
    # DP settings; presumably only used when private=True -- TODO confirm in FLServer
    max_grad_norm=1.0,
    delta=1e-4,
    epsilon=5,
    local_benchmark_epochs=70,
)

N_pub = len(fl_params['public_set'][0])
run_tag = (
    f"Agg{fl_params['aggregate']}_C{fl_params['C']}_HT{fl_params['hyperparameter_tuning']}"
    f"_Aug{fl_params['augment']}_W{fl_params['weighting']}"
)
exp_path = f"../fl_results/{dataset}/DP{private}/N_pub{N_pub}/{run_tag}"
fl_params['exp_path'] = exp_path
server = FLServer(fl_params)

# Federated training: record the average test accuracy of every round.
FL_acc = []
for rnd in range(fl_params['tot_T']):
    avg_acc, min_acc, max_acc, avg_train_acc = server.global_update()
    FL_acc.append(avg_acc)
    print(f"Round {rnd} accuracy: avg: {avg_acc} min: {min_acc}  max: {max_acc} avg_train: {avg_train_acc}")
print("Final accuracy: ", FL_acc[-1])
print()

# Local-only benchmark, currently disabled for this run:
# client_accs = []
# for c, client in enumerate(server.clients):
#     acc = client.local_benchmark()
#     client_accs.append(acc)
#     print(f"Client {c} local benchmark accuracy: {acc}")
# print("Client accuracies: ", client_accs, "  ", np.mean(client_accs))
# print()

server.save_assets()


## FedAKD

Clients exchange soft labels (`aggregate='soft_labels'`) with augmentation enabled.

In [None]:
# FedAKD benchmark: soft-label aggregation over all clients for 50 rounds,
# with augmentation enabled.
aggregation_method = 'soft_labels'  # one of 'weights', 'grads', 'soft_labels', 'compressed_soft_labels'
aug = True
weighting = 'uniform'               # 'uniform' or 'performance_based'
private = False
hyperparameter_tuning = False
C = 1
# Defined here instead of being inherited from an earlier exploratory cell, so this
# cell does not depend on hidden kernel state. The value (4) matches those cells.
initial_pub_alignment_epochs = 4


fl_params = {
    # Use every client, as the other benchmark runs do. client_num is not encoded in
    # exp_path, so a different client count would be silently compared against them.
    'client_num': len(local_sets),
    'tot_T': 50,  # number of communication rounds
    'C': C,
    'local_sets': local_sets,
    'test_sets': test_sets,
    'public_set': public_set,
    'batch_size': 32,
    'epochs': 3,
    'lr': 0.005,
    'aggregate': aggregation_method,
    'hyperparameter_tuning': hyperparameter_tuning,
    'weighting': weighting,
    'default_client_id': 1,
    'augment': aug,
    'private': private,
    # DP settings; presumably only used when private=True -- TODO confirm in FLServer
    'max_grad_norm': 1.0,
    'delta': 1e-4,
    'epsilon': 5,
    'local_benchmark_epochs': 70,
    'initial_pub_alignment_epochs': initial_pub_alignment_epochs
}

N_pub = len(fl_params['public_set'][0])
exp_path = f"../fl_results/{dataset}/DP{private}/N_pub{N_pub}/Agg{fl_params['aggregate']}_C{fl_params['C']}_HT{fl_params['hyperparameter_tuning']}_Aug{fl_params['augment']}_W{fl_params['weighting']}"
fl_params['exp_path'] = exp_path
server = FLServer(fl_params)

# Federated training: record the average test accuracy of every round.
FL_acc = []
for t in range(fl_params['tot_T']):
    avg_acc, min_acc, max_acc, avg_train_acc = server.global_update()
    FL_acc.append(avg_acc)
    print(f"Round {t} accuracy: avg: {avg_acc} min: {min_acc}  max: {max_acc} avg_train: {avg_train_acc}")
print("Final accuracy: ", FL_acc[-1])
print()

# Local-only benchmark, currently disabled for this run:
# client_accs = []
# for c, client in enumerate(server.clients):
#     acc = client.local_benchmark()
#     client_accs.append(acc)
#     print(f"Client {c} local benchmark accuracy: {acc}")
# print("Client accuracies: ", client_accs, "  ", np.mean(client_accs))
# print()

server.save_assets()
