# Imports

In [1]:
from data_utils import * 
from model_utils import *
from utils import FLClient, FLServer
import torch
import numpy as np
import pandas as pd
import matplotlib.pyplot as plt
from itertools import product
import os
import json
import seaborn as sns
import matplotlib.pyplot as plt
import matplotlib
import tensorflow as tf
from tensorflow.keras.utils import to_categorical

  from .autonotebook import tqdm as notebook_tqdm


# Functions

# Load CIFAR-10 (client data) and CIFAR-100 (public set) for FL

In [2]:
# Experiment-wide data configuration.
# `dataset` is the clients' (private) dataset; the public set is always the
# other CIFAR variant. Both names are derived from `dataset` so the class
# counts used for one-hot encoding cannot drift from the data that is loaded.
dataset = 'cifar10'
public_dataset = 'cifar100' if dataset == 'cifar10' else 'cifar10'
num_classes = 10 if dataset == 'cifar10' else 100
pub_num_classes = 100 if num_classes == 10 else 10
datadir = '../data'
partition = 'iid'
n_parties = 5
# NOTE(review): beta is forwarded to partition_data; presumably only relevant
# for non-IID partitioning -- confirm in data_utils.
beta = 0.5
# Number of shards the public dataset is split into (only shard 0 is used
# later as the public set).
n_public_shards = 50

# Private data: partitioned across the `n_parties` clients.
(X_train, y_train, X_test, y_test, net_dataidx_map) = partition_data(
    dataset, datadir=datadir, partition=partition, n_parties=n_parties, beta=beta
)
# Public data: always split IID, independent of the private partition scheme.
(X_train_public, y_train_public, X_test_public, y_test_public, net_dataidx_map_public) = partition_data(
    public_dataset, datadir=datadir, partition='iid', n_parties=n_public_shards, beta=0.5
)

# One-hot encode the integer labels of both datasets.
y_train_cat = to_categorical(y_train, num_classes=num_classes)
y_test_cat = to_categorical(y_test, num_classes=num_classes)
y_train_public_cat = to_categorical(y_train_public, num_classes=pub_num_classes)
y_test_public_cat = to_categorical(y_test_public, num_classes=pub_num_classes)

# Sanity check: array shapes of the private and the public splits.
print(X_train.shape, y_train.shape, X_test.shape, y_test.shape, y_train_cat.shape, y_test_cat.shape)
print(X_train_public.shape, y_train_public.shape, X_test_public.shape, y_test_public.shape, y_train_public_cat.shape, y_test_public_cat.shape)

Files already downloaded and verified
Files already downloaded and verified
Files already downloaded and verified
Files already downloaded and verified
(50000, 32, 32, 3) (50000,) (10000, 32, 32, 3) (10000,) (50000, 10) (10000, 10)
(50000, 32, 32, 3) (50000,) (10000, 32, 32, 3) (10000,) (50000, 100) (10000, 100)


In [3]:
# Public set: shard 0 of the public partition (images + one-hot labels).
public_idx = net_dataidx_map_public[0]
public_set = (X_train_public[public_idx], y_train_public_cat[public_idx])

# One (images, one-hot labels) pair per client, selected through that
# client's index list in `net_dataidx_map`.
local_sets = [
    (X_train[net_dataidx_map[party]], y_train_cat[net_dataidx_map[party]])
    for party in range(n_parties)
]
# Every client is evaluated on the same, complete test split.
test_sets = [(X_test, y_test_cat) for _ in range(n_parties)]

# Report the shapes of the public set and of each client's train/test data.
print(public_set[0].shape, public_set[1].shape)
for i, (train_pair, test_pair) in enumerate(zip(local_sets, test_sets)):
    print('client ', i, ' ', train_pair[0].shape, train_pair[1].shape)
    print(test_pair[0].shape, test_pair[1].shape)
    print()


(1000, 32, 32, 3) (1000, 100)
client  0   (10000, 32, 32, 3) (10000, 10)
(10000, 32, 32, 3) (10000, 10)

client  1   (10000, 32, 32, 3) (10000, 10)
(10000, 32, 32, 3) (10000, 10)

client  2   (10000, 32, 32, 3) (10000, 10)
(10000, 32, 32, 3) (10000, 10)

client  3   (10000, 32, 32, 3) (10000, 10)
(10000, 32, 32, 3) (10000, 10)

client  4   (10000, 32, 32, 3) (10000, 10)
(10000, 32, 32, 3) (10000, 10)



In [4]:
# Server configured for plain (uncompressed) soft-label aggregation.
# It is only constructed here; no training round is run in this cell.
aggregation_method, weighting = 'soft_labels', 'uniform'
aug = private = hyperparameter_tuning = False
C = 1

fl_params = {
    # --- federation layout and data ---
    'client_num': len(local_sets),
    'tot_T': 50,
    'C': C,
    'local_sets': local_sets,
    'test_sets': test_sets,
    'public_set': public_set,
    # --- training hyperparameters ---
    'batch_size': 32,
    'epochs': 3,
    'lr': 0.005,
    # --- aggregation / strategy switches ---
    'aggregate': aggregation_method,  # 'soft_labels' here; this notebook also uses 'compressed_soft_labels' and 'weights'
    'hyperparameter_tuning': hyperparameter_tuning,
    'weighting': weighting,  # author-listed options: 'uniform', 'performance_based'
    'default_client_id': 1,
    'augment': aug,
    # --- privacy-related settings (presumably only used when 'private' is True -- confirm in FLServer) ---
    'private': private,
    'max_grad_norm': 1.0,
    'delta': 1e-4,
    'epsilon': 5,
    # --- local benchmark ---
    'local_benchmark_epochs': 70,
}

# The result directory encodes the configuration switches of this experiment.
N_pub = len(public_set[0])
exp_path = f"../fl_results/{dataset}/DP{private}/N_pub{N_pub}/Agg{fl_params['aggregate']}_C{fl_params['C']}_HT{fl_params['hyperparameter_tuning']}_Aug{fl_params['augment']}_W{fl_params['weighting']}"
fl_params['exp_path'] = exp_path

server = FLServer(fl_params)

In [5]:
# Second server, identical to the soft-label configuration except that it
# aggregates compressed soft labels. It is kept under its own name
# (`compressed_server`) so both servers can be compared side by side.
aggregation_method, weighting = 'compressed_soft_labels', 'uniform'
aug = private = hyperparameter_tuning = False
C = 1

fl_params = {
    # --- federation layout and data ---
    'client_num': len(local_sets),
    'tot_T': 50,
    'C': C,
    'local_sets': local_sets,
    'test_sets': test_sets,
    'public_set': public_set,
    # --- training hyperparameters ---
    'batch_size': 32,
    'epochs': 3,
    'lr': 0.005,
    # --- aggregation / strategy switches ---
    'aggregate': aggregation_method,  # 'compressed_soft_labels' here; this notebook also uses 'soft_labels' and 'weights'
    'hyperparameter_tuning': hyperparameter_tuning,
    'weighting': weighting,  # author-listed options: 'uniform', 'performance_based'
    'default_client_id': 1,
    'augment': aug,
    # --- privacy-related settings (presumably only used when 'private' is True -- confirm in FLServer) ---
    'private': private,
    'max_grad_norm': 1.0,
    'delta': 1e-4,
    'epsilon': 5,
    # --- local benchmark ---
    'local_benchmark_epochs': 70,
}

# The result directory encodes the configuration switches of this experiment.
N_pub = len(public_set[0])
exp_path = f"../fl_results/{dataset}/DP{private}/N_pub{N_pub}/Agg{fl_params['aggregate']}_C{fl_params['C']}_HT{fl_params['hyperparameter_tuning']}_Aug{fl_params['augment']}_W{fl_params['weighting']}"
fl_params['exp_path'] = exp_path

compressed_server = FLServer(fl_params)

In [7]:
from utils import size_of

# Collect the three candidate payloads from client 0: plain soft labels,
# normalized + compressed soft labels, and the client's model state dict.
soft_labels = server.clients[0].get_soft_labels()
compressed_soft_labels = compressed_server.clients[0].get_soft_labels(normalize=True, compress=True)
clientmodel = server.clients[0].model
clientmodel_state = clientmodel.state_dict()

# Measure every payload with the same unit ('KB').
ssize, cssize, msize = (
    size_of(payload, size='KB')
    for payload in (soft_labels, compressed_soft_labels, clientmodel_state)
)

# Cell output: (compressed soft labels, soft labels, model state).
cssize, ssize, msize
# Recorded result of a previous run: (10.0, 40.0, 400.832)

(10.0, 40.0, 400.832)

## FedAvg

In [4]:
# FedAvg configuration: aggregate='weights', no augmentation, uniform weighting.
aggregation_method, weighting = 'weights', 'uniform'
aug = private = hyperparameter_tuning = False
C = 1

fl_params = {
    # --- federation layout and data ---
    'client_num': len(local_sets),
    'tot_T': 50,  # number of global rounds executed by the loop below
    'C': C,
    'local_sets': local_sets,
    'test_sets': test_sets,
    'public_set': public_set,
    # --- training hyperparameters ---
    'batch_size': 32,
    'epochs': 3,
    'lr': 0.005,
    # --- aggregation / strategy switches ---
    'aggregate': aggregation_method,  # 'weights' here; this notebook also uses 'soft_labels' and 'compressed_soft_labels'
    'hyperparameter_tuning': hyperparameter_tuning,
    'weighting': weighting,  # author-listed options: 'uniform', 'performance_based'
    'default_client_id': 1,
    'augment': aug,
    # --- privacy-related settings (presumably only used when 'private' is True -- confirm in FLServer) ---
    'private': private,
    'max_grad_norm': 1.0,
    'delta': 1e-4,
    'epsilon': 5,
    # --- local benchmark ---
    'local_benchmark_epochs': 70,
}

# The result directory encodes the configuration switches of this experiment.
N_pub = len(public_set[0])
exp_path = f"../fl_results/{dataset}/DP{private}/N_pub{N_pub}/Agg{fl_params['aggregate']}_C{fl_params['C']}_HT{fl_params['hyperparameter_tuning']}_Aug{fl_params['augment']}_W{fl_params['weighting']}"
fl_params['exp_path'] = exp_path

server = FLServer(fl_params)

# Federated training: one global update per round; the per-round average
# accuracy is kept so the last entry can be reported as the final accuracy.
FL_acc = []
for t in range(fl_params['tot_T']):
    round_stats = server.global_update()
    avg_acc, min_acc, max_acc, avg_train_acc = round_stats
    FL_acc.append(avg_acc)
    print(f"Round {t} accuracy: avg: {avg_acc} min: {min_acc}  max: {max_acc} avg_train: {avg_train_acc}")
print("Final accuracy: ", FL_acc[-1])
print()

# Per-client local benchmark, reported individually and as a mean.
client_accs = []
for c, client in enumerate(server.clients):
    acc = client.local_benchmark()
    client_accs.append(acc)
    print(f"Client {c} local benchmark accuracy: {acc}")
print("Client accuracies: ", client_accs, "  ", np.mean(client_accs))
print()

# Persist the run via the server (presumably under `exp_path` -- confirm in FLServer).
server.save_assets()


Round 0 accuracy: avg: 0.11605431309904153 min: 0.10113817891373802  max: 0.128694089456869 avg_train: 0.1244047619047619
Round 1 accuracy: avg: 0.1418929712460064 min: 0.12100638977635783  max: 0.15545127795527156 avg_train: 0.14484126984126983
Round 2 accuracy: avg: 0.15702875399361022 min: 0.13428514376996806  max: 0.16633386581469647 avg_train: 0.1626984126984127
Round 3 accuracy: avg: 0.17697683706070289 min: 0.1566493610223642  max: 0.2027755591054313 avg_train: 0.18115079365079367
Round 4 accuracy: avg: 0.18546325878594247 min: 0.15634984025559107  max: 0.20547124600638977 avg_train: 0.19563492063492063
Round 5 accuracy: avg: 0.20117811501597443 min: 0.18610223642172524  max: 0.22304313099041534 avg_train: 0.20684523809523808
Round 6 accuracy: avg: 0.20451277955271566 min: 0.14237220447284346  max: 0.23851837060702874 avg_train: 0.22123015873015875
Round 7 accuracy: avg: 0.22695686900958467 min: 0.21924920127795527  max: 0.23961661341853036 avg_train: 0.2373015873015873
Round 8 

## FedMD

In [5]:
# FedMD configuration: aggregate='soft_labels', no augmentation, uniform weighting.
aggregation_method, weighting = 'soft_labels', 'uniform'
aug = private = hyperparameter_tuning = False
C = 1

fl_params = {
    # --- federation layout and data ---
    'client_num': len(local_sets),
    'tot_T': 50,  # number of global rounds executed by the loop below
    'C': C,
    'local_sets': local_sets,
    'test_sets': test_sets,
    'public_set': public_set,
    # --- training hyperparameters ---
    'batch_size': 32,
    'epochs': 3,
    'lr': 0.005,
    # --- aggregation / strategy switches ---
    'aggregate': aggregation_method,  # 'soft_labels' here; this notebook also uses 'compressed_soft_labels' and 'weights'
    'hyperparameter_tuning': hyperparameter_tuning,
    'weighting': weighting,  # author-listed options: 'uniform', 'performance_based'
    'default_client_id': 1,
    'augment': aug,
    # --- privacy-related settings (presumably only used when 'private' is True -- confirm in FLServer) ---
    'private': private,
    'max_grad_norm': 1.0,
    'delta': 1e-4,
    'epsilon': 5,
    # --- local benchmark ---
    'local_benchmark_epochs': 70,
}

# The result directory encodes the configuration switches of this experiment.
N_pub = len(public_set[0])
exp_path = f"../fl_results/{dataset}/DP{private}/N_pub{N_pub}/Agg{fl_params['aggregate']}_C{fl_params['C']}_HT{fl_params['hyperparameter_tuning']}_Aug{fl_params['augment']}_W{fl_params['weighting']}"
fl_params['exp_path'] = exp_path

server = FLServer(fl_params)

# Federated training: one global update per round; the per-round average
# accuracy is kept so the last entry can be reported as the final accuracy.
FL_acc = []
for t in range(fl_params['tot_T']):
    round_stats = server.global_update()
    avg_acc, min_acc, max_acc, avg_train_acc = round_stats
    FL_acc.append(avg_acc)
    print(f"Round {t} accuracy: avg: {avg_acc} min: {min_acc}  max: {max_acc} avg_train: {avg_train_acc}")
print("Final accuracy: ", FL_acc[-1])
print()

# The per-client local benchmark is left commented out in this cell:
# client_accs = []
# for c, client in enumerate(server.clients):
#     acc = client.local_benchmark()
#     client_accs.append(acc)
#     print(f"Client {c} local benchmark accuracy: {acc}")
# print("Client accuracies: ", client_accs, "  ", np.mean(client_accs))
# print()

# Persist the run via the server (presumably under `exp_path` -- confirm in FLServer).
server.save_assets()


Round 0 accuracy: avg: 0.11876996805111824 min: 0.11002396166134186  max: 0.13738019169329074 avg_train: 0.12152777777777776
Round 1 accuracy: avg: 0.14071485623003194 min: 0.10463258785942492  max: 0.17172523961661343 avg_train: 0.1403769841269841
Round 2 accuracy: avg: 0.14438897763578276 min: 0.10383386581469649  max: 0.18360623003194887 avg_train: 0.16071428571428575
Round 3 accuracy: avg: 0.1606629392971246 min: 0.10632987220447285  max: 0.21156150159744408 avg_train: 0.16686507936507936
Round 4 accuracy: avg: 0.1535143769968051 min: 0.1077276357827476  max: 0.21226038338658146 avg_train: 0.1601190476190476
Round 5 accuracy: avg: 0.16018370607028753 min: 0.10193690095846646  max: 0.21944888178913738 avg_train: 0.16875
Round 6 accuracy: avg: 0.1586461661341853 min: 0.11331869009584665  max: 0.20786741214057508 avg_train: 0.17628968253968255
Round 7 accuracy: avg: 0.17266373801916934 min: 0.11002396166134186  max: 0.24051517571884984 avg_train: 0.1824404761904762
Round 8 accuracy: a

## CFedAKD

In [6]:
# CFedAKD configuration: aggregate='compressed_soft_labels' with augmentation
# switched on, uniform weighting.
aggregation_method, weighting = 'compressed_soft_labels', 'uniform'
aug = True
private = hyperparameter_tuning = False
C = 1

fl_params = {
    # --- federation layout and data ---
    'client_num': len(local_sets),
    'tot_T': 50,  # number of global rounds executed by the loop below
    'C': C,
    'local_sets': local_sets,
    'test_sets': test_sets,
    'public_set': public_set,
    # --- training hyperparameters ---
    'batch_size': 32,
    'epochs': 3,
    'lr': 0.005,
    # --- aggregation / strategy switches ---
    'aggregate': aggregation_method,  # 'compressed_soft_labels' here; this notebook also uses 'soft_labels' and 'weights'
    'hyperparameter_tuning': hyperparameter_tuning,
    'weighting': weighting,  # author-listed options: 'uniform', 'performance_based'
    'default_client_id': 1,
    'augment': aug,
    # --- privacy-related settings (presumably only used when 'private' is True -- confirm in FLServer) ---
    'private': private,
    'max_grad_norm': 1.0,
    'delta': 1e-4,
    'epsilon': 5,
    # --- local benchmark ---
    'local_benchmark_epochs': 70,
}

# The result directory encodes the configuration switches of this experiment.
N_pub = len(public_set[0])
exp_path = f"../fl_results/{dataset}/DP{private}/N_pub{N_pub}/Agg{fl_params['aggregate']}_C{fl_params['C']}_HT{fl_params['hyperparameter_tuning']}_Aug{fl_params['augment']}_W{fl_params['weighting']}"
fl_params['exp_path'] = exp_path

server = FLServer(fl_params)

# Federated training: one global update per round; the per-round average
# accuracy is kept so the last entry can be reported as the final accuracy.
FL_acc = []
for t in range(fl_params['tot_T']):
    round_stats = server.global_update()
    avg_acc, min_acc, max_acc, avg_train_acc = round_stats
    FL_acc.append(avg_acc)
    print(f"Round {t} accuracy: avg: {avg_acc} min: {min_acc}  max: {max_acc} avg_train: {avg_train_acc}")
print("Final accuracy: ", FL_acc[-1])
print()

# The per-client local benchmark is left commented out in this cell:
# client_accs = []
# for c, client in enumerate(server.clients):
#     acc = client.local_benchmark()
#     client_accs.append(acc)
#     print(f"Client {c} local benchmark accuracy: {acc}")
# print("Client accuracies: ", client_accs, "  ", np.mean(client_accs))
# print()

# Persist the run via the server (presumably under `exp_path` -- confirm in FLServer).
server.save_assets()


Round 0 accuracy: avg: 0.12785543130990415 min: 0.10273562300319489  max: 0.1387779552715655 avg_train: 0.12916666666666665
Round 1 accuracy: avg: 0.14540734824281148 min: 0.12839456869009586  max: 0.1706269968051118 avg_train: 0.14970238095238095
Round 2 accuracy: avg: 0.15948482428115016 min: 0.1420726837060703  max: 0.17521964856230032 avg_train: 0.16279761904761908
Round 3 accuracy: avg: 0.178694089456869 min: 0.15595047923322683  max: 0.19408945686900958 avg_train: 0.17906746031746032
Round 4 accuracy: avg: 0.18871805111821086 min: 0.16114217252396165  max: 0.2163538338658147 avg_train: 0.1886904761904762
Round 5 accuracy: avg: 0.20309504792332272 min: 0.1790135782747604  max: 0.23552316293929712 avg_train: 0.20674603174603176
Round 6 accuracy: avg: 0.21122204472843448 min: 0.17641773162939298  max: 0.24141373801916932 avg_train: 0.21001984126984125
Round 7 accuracy: avg: 0.22186501597444086 min: 0.18829872204472844  max: 0.25139776357827476 avg_train: 0.22033730158730158
Round 8 

## FedAKD

In [7]:
# FedAKD configuration: aggregate='soft_labels' with augmentation switched on,
# uniform weighting.
aggregation_method, weighting = 'soft_labels', 'uniform'
aug = True
private = hyperparameter_tuning = False
C = 1

fl_params = {
    # --- federation layout and data ---
    'client_num': len(local_sets),
    'tot_T': 50,  # number of global rounds executed by the loop below
    'C': C,
    'local_sets': local_sets,
    'test_sets': test_sets,
    'public_set': public_set,
    # --- training hyperparameters ---
    'batch_size': 32,
    'epochs': 3,
    'lr': 0.005,
    # --- aggregation / strategy switches ---
    'aggregate': aggregation_method,  # 'soft_labels' here; this notebook also uses 'compressed_soft_labels' and 'weights'
    'hyperparameter_tuning': hyperparameter_tuning,
    'weighting': weighting,  # author-listed options: 'uniform', 'performance_based'
    'default_client_id': 1,
    'augment': aug,
    # --- privacy-related settings (presumably only used when 'private' is True -- confirm in FLServer) ---
    'private': private,
    'max_grad_norm': 1.0,
    'delta': 1e-4,
    'epsilon': 5,
    # --- local benchmark ---
    'local_benchmark_epochs': 70,
}

# The result directory encodes the configuration switches of this experiment.
N_pub = len(public_set[0])
exp_path = f"../fl_results/{dataset}/DP{private}/N_pub{N_pub}/Agg{fl_params['aggregate']}_C{fl_params['C']}_HT{fl_params['hyperparameter_tuning']}_Aug{fl_params['augment']}_W{fl_params['weighting']}"
fl_params['exp_path'] = exp_path

server = FLServer(fl_params)

# Federated training: one global update per round; the per-round average
# accuracy is kept so the last entry can be reported as the final accuracy.
FL_acc = []
for t in range(fl_params['tot_T']):
    round_stats = server.global_update()
    avg_acc, min_acc, max_acc, avg_train_acc = round_stats
    FL_acc.append(avg_acc)
    print(f"Round {t} accuracy: avg: {avg_acc} min: {min_acc}  max: {max_acc} avg_train: {avg_train_acc}")
print("Final accuracy: ", FL_acc[-1])
print()

# The per-client local benchmark is left commented out in this cell:
# client_accs = []
# for c, client in enumerate(server.clients):
#     acc = client.local_benchmark()
#     client_accs.append(acc)
#     print(f"Client {c} local benchmark accuracy: {acc}")
# print("Client accuracies: ", client_accs, "  ", np.mean(client_accs))
# print()

# Persist the run via the server (presumably under `exp_path` -- confirm in FLServer).
server.save_assets()


Round 0 accuracy: avg: 0.13560303514376998 min: 0.11721246006389777  max: 0.15575079872204473 avg_train: 0.13819444444444445
Round 1 accuracy: avg: 0.14942092651757188 min: 0.10293530351437699  max: 0.17172523961661343 avg_train: 0.15049603174603174
Round 2 accuracy: avg: 0.1623402555910543 min: 0.1005391373801917  max: 0.20357428115015974 avg_train: 0.15823412698412698
Round 3 accuracy: avg: 0.17947284345047923 min: 0.10003993610223642  max: 0.22184504792332269 avg_train: 0.1736111111111111
Round 4 accuracy: avg: 0.17366214057507987 min: 0.10043929712460063  max: 0.21705271565495207 avg_train: 0.18759920634920632
Round 5 accuracy: avg: 0.18971645367412143 min: 0.10083865814696485  max: 0.2349241214057508 avg_train: 0.2021825396825397
Round 6 accuracy: avg: 0.19746405750798723 min: 0.10043929712460063  max: 0.2423123003194888 avg_train: 0.2115079365079365
Round 7 accuracy: avg: 0.2104233226837061 min: 0.10812699680511183  max: 0.25748801916932906 avg_train: 0.2170634920634921
Round 8 a