In [1]:
import os
from os.path import join
import errno
import argparse
import sys
import pickle
import json 

import numpy as np
from tensorflow.keras.models import load_model
import tensorflow as tf

from data_utils import load_MNIST_data, load_FEMNIST_data, load_EMNIST_data, generate_bal_private_data
from data_utils import generate_partial_data, load_ready_data
from FedMD import FedMD, FedAvg
from Neural_Networks import train_models, cnn_2layer_fc_model, cnn_3layer_fc_model
from utility import * 

import pandas as pd            # For data manipulation
import seaborn as sns          # For plotting heatmap
import matplotlib.pyplot as plt  # For visualization and saving the plot
import logging



## Load config file

In [2]:
# Private dataset for this run; one of 'CIFAR10', 'CIFAR100', 'FEMNIST', 'MNIST'.
private_dataset_name = 'CIFAR10'

# The public dataset is the "sibling" of the private one inside the same family
# (CIFAR10 <-> CIFAR100, FEMNIST -> MNIST); anything else falls back to FEMNIST.
_PUBLIC_FOR_PRIVATE = {
    'CIFAR10': 'CIFAR100',
    'CIFAR100': 'CIFAR10',
    'FEMNIST': 'MNIST',
}
public_dataset_name = _PUBLIC_FOR_PRIVATE.get(private_dataset_name, 'FEMNIST')

print("private dataset: {0}".format(private_dataset_name))
print("public dataset: {0}".format(public_dataset_name))

# Registry of model-builder functions, looked up later by the "model_type"
# string of each entry in the model configuration.
CANDIDATE_MODELS = dict([
    ("2_layer_CNN", cnn_2layer_fc_model),
    ("3_layer_CNN", cnn_3layer_fc_model),
])


# Pick the JSON configuration matching the dataset family of the private dataset.
if private_dataset_name in ["CIFAR10", "CIFAR100"]:
    conf_file = os.path.abspath("../conf/CIFAR_balance_conf.json")
else:
    conf_file = os.path.abspath("../conf/MNIST_balance_conf.json")

# Only parsing needs the file handle, so close it before touching the values.
with open(conf_file, "r") as f:
    conf_dict = json.load(f)

# --- Model / pre-training settings ---
model_config = conf_dict["models"]
pre_train_params = conf_dict["pre_train_params"]
model_saved_dir = conf_dict["model_saved_dir"]
model_saved_names = conf_dict["model_saved_names"]
is_early_stopping = conf_dict["early_stopping"]
public_classes = conf_dict["public_classes"]
private_classes = conf_dict["private_classes"]
# n_classes is derived from the public class list instead of being read from
# an "n_classes" entry of the configuration.
n_classes = len(public_classes)

# --- Federation layout ---
N_parties = conf_dict["N_parties"]
N_samples_per_class = conf_dict["N_samples_per_class"]

# --- Collaborative training schedule and options ---
N_rounds = conf_dict["N_rounds"]
N_alignment = conf_dict["N_alignment"]
N_private_training_round = conf_dict["N_private_training_round"]
private_training_batchsize = conf_dict["private_training_batchsize"]
N_logits_matching_round = conf_dict["N_logits_matching_round"]
logits_matching_batchsize = conf_dict["logits_matching_batchsize"]
aug = conf_dict["aug"]
compress = conf_dict["compress"]
select = conf_dict["select"]
algorithm = conf_dict["algorithm"]

# --- Input / output locations ---
dataset_dir = conf_dict["dataset_dir"]
result_save_dir = conf_dict["result_save_dir"]

# Encode the experiment settings in the output directory name.
# Any algorithm other than 'fedavg' / 'fedmd' leaves the name without a suffix.
if algorithm == 'fedavg':
    result_save_dir = result_save_dir + "_fedavg"

elif algorithm == 'fedmd':
    result_save_dir = result_save_dir + "_fedmd"

    if aug:
        print("adding aug")
        result_save_dir = result_save_dir + "_aug"
    if compress:
        print("adding compress")
        result_save_dir = result_save_dir + "_compress"
    if select:
        print("adding select")
        result_save_dir = result_save_dir + "_select"
    print("Using {} alignment".format(N_alignment))
    result_save_dir = result_save_dir + "_exp{}".format(N_alignment)

# Never write into an existing result directory: append a random numeric
# suffix, and keep drawing until the candidate name is actually unused
# (a single draw could hit a directory left behind by an earlier run, which
# would make os.makedirs raise FileExistsError).
base_result_save_dir = result_save_dir
while os.path.exists(result_save_dir):
    result_save_dir = base_result_save_dir + "_{}".format(np.random.randint(1000))
os.makedirs(result_save_dir)


# Drop the temporaries so only the unpacked settings remain in the namespace.
del conf_dict, conf_file, base_result_save_dir


private dataset: CIFAR10
public dataset: CIFAR100


## Helper functions

In [3]:

from PIL import Image

def all_digit(x) :
    """Return True when every item of ``x`` reports ``isdigit()``.

    ``x`` is any iterable whose items have an ``isdigit`` method (typically a
    string, checked character by character). An empty ``x`` yields True.
    """
    verdict = True
    for ch in x:
        # Every item is queried, even after a failure has been seen.
        verdict = bool(ch.isdigit()) and verdict
    return verdict

# resize image to shape 
def resize_this_image(x, shape, denormalize = True, normalize_back = True) :
    """Resize a single image array with PIL and return it as float32.

    Parameters
    ----------
    x : numpy array holding one image. With ``denormalize=True`` it is mapped
        to 8-bit pixels via ``(x + 0.5) * 255``; otherwise it is handed to
        ``Image.fromarray`` unchanged.
    shape : target size, passed straight to ``Image.resize``
        (NOTE(review): PIL interprets this as (width, height) - confirm callers).
    denormalize : convert ``x`` to uint8 pixels before resizing.
    normalize_back : map the resized pixels back with ``y / 255 - 0.5``.

    Returns
    -------
    numpy float32 array with the resized image.
    """
    if denormalize :
        # Round to the nearest pixel level and clip to the valid 8-bit range
        # before casting: a bare astype(np.uint8) truncates, so floating-point
        # error (e.g. 99.99999) would shift pixels down by one level, and
        # values outside [0, 255] would wrap around.
        x = np.clip(np.rint((x + 0.5) * 255.0), 0, 255).astype(np.uint8)
    y = np.array(Image.fromarray(x).resize(shape), dtype = np.float32)
    if normalize_back :
        y = y / 255.0 - 0.5
    return y


def resize_dataset(x, new_shape) :
    """Resize every image of ``x`` (indexed along its first axis) to ``new_shape``.

    Each ``x[idx, ...]`` is passed through ``resize_this_image`` with its
    default options, and the results are stacked into a new numpy array.
    """
    resized = [resize_this_image(x[idx, ...], new_shape) for idx in range(len(x))]
    return np.array(resized)


## Prepare dataset

In [4]:




# Load the prepared splits for the private dataset chosen in the config cell.
data_dir = "../data"
dataset = private_dataset_name
clients_data, alignment_data, test_data = load_ready_data(data_dir, dataset)

# NOTE(review): the class counts are taken as the lengths of the returned
# containers - presumably they are organised per class; confirm against
# load_ready_data.
n_public_classes, n_private_classes = len(alignment_data), len(test_data)

In [5]:
# Sanity check: length of the first component of the first client's data
# (presumably that client's sample count - TODO confirm the layout).
len(clients_data[0][0])

50

In [6]:
# NOTE(review): this hard-coded value replaces the `algorithm` read from the
# config file, while `result_save_dir` was already named after the config
# value - make sure the two agree before trusting the output directory name.
algorithm = 'fedavg'
input_shape = (32, 32, 3)

# Build one model per party. Under 'fedavg' every party uses the first model
# configuration; otherwise party i uses configuration i.
parties = []
for i in range(N_parties):
    model_idx = 0 if algorithm == 'fedavg' else i
    item = model_config[model_idx]
    build_model = CANDIDATE_MODELS[item['model_type']]
    parties.append(build_model(n_classes = n_private_classes,
                               input_shape = input_shape,
                               **item['params']))


len(parties)

Metal device set to: Apple M1

systemMemory: 8.00 GB
maxCacheSize: 2.67 GB



2023-11-20 13:38:50.424749: I tensorflow/core/common_runtime/pluggable_device/pluggable_device_factory.cc:306] Could not identify NUMA node of platform GPU ID 0, defaulting to 0. Your kernel may not have been built with NUMA support.
2023-11-20 13:38:50.425116: I tensorflow/core/common_runtime/pluggable_device/pluggable_device_factory.cc:272] Created TensorFlow device (/job:localhost/replica:0/task:0/device:GPU:0 with 0 MB memory) -> physical PluggableDevice (device: 0, name: METAL, pci bus id: <undefined>)
  super().__init__(name, **kwargs)


10

In [7]:
# Supported collaborative-training algorithms, keyed by their config name.
algorithms = {'fedavg': FedAvg, 'fedmd': FedMD}

# NOTE(review): overrides the N_rounds value loaded from the config file.
N_rounds = 20

if algorithm == 'fedavg':
    alg = algorithms[algorithm](parties, clients_data, test_data, N_rounds = N_rounds,
                                N_private_training_round = N_private_training_round,
                                private_training_batchsize = private_training_batchsize)
elif algorithm == 'fedmd':
    alg = algorithms[algorithm](parties,
                                original_public_dataset = alignment_data,
                                private_data = clients_data,
                                private_test_data = test_data,
                                N_rounds = N_rounds, N_alignment = N_alignment,
                                N_logits_matching_round = N_logits_matching_round,
                                logits_matching_batchsize = logits_matching_batchsize,
                                N_private_training_round = N_private_training_round,
                                private_training_batchsize = private_training_batchsize,
                                aug = aug, compress = compress, select = select)
else:
    # Fail here with a clear message instead of a NameError on `alg` below.
    raise ValueError("Unsupported algorithm {!r}; expected one of {}".format(
        algorithm, sorted(algorithms)))

# Run the collaborative training and keep the value it returns for saving later.
collaboration_performance = alg.collaborative_training()




model  0
model  1
model  2
model  3
model  4
model  5
model  6
model  7
model  8
model  9
round  0


2023-11-20 13:38:51.829269: W tensorflow/core/platform/profile_utils/cpu_utils.cc:128] Failed to get CPU frequency: 0 Hz
2023-11-20 13:38:52.171380: I tensorflow/core/grappler/optimizers/custom_graph_optimizer_registry.cc:114] Plugin optimizer for device_type GPU is enabled.
2023-11-20 13:38:53.849576: I tensorflow/core/grappler/optimizers/custom_graph_optimizer_registry.cc:114] Plugin optimizer for device_type GPU is enabled.
2023-11-20 13:38:54.417976: I tensorflow/core/grappler/optimizers/custom_graph_optimizer_registry.cc:114] Plugin optimizer for device_type GPU is enabled.
2023-11-20 13:38:55.007118: I tensorflow/core/grappler/optimizers/custom_graph_optimizer_registry.cc:114] Plugin optimizer for device_type GPU is enabled.
2023-11-20 13:38:55.580916: I tensorflow/core/grappler/optimizers/custom_graph_optimizer_registry.cc:114] Plugin optimizer for device_type GPU is enabled.




2023-11-20 13:38:56.272983: I tensorflow/core/grappler/optimizers/custom_graph_optimizer_registry.cc:114] Plugin optimizer for device_type GPU is enabled.




2023-11-20 13:38:56.854509: I tensorflow/core/grappler/optimizers/custom_graph_optimizer_registry.cc:114] Plugin optimizer for device_type GPU is enabled.
2023-11-20 13:38:57.432471: I tensorflow/core/grappler/optimizers/custom_graph_optimizer_registry.cc:114] Plugin optimizer for device_type GPU is enabled.
2023-11-20 13:38:58.025667: I tensorflow/core/grappler/optimizers/custom_graph_optimizer_registry.cc:114] Plugin optimizer for device_type GPU is enabled.
2023-11-20 13:38:58.601606: I tensorflow/core/grappler/optimizers/custom_graph_optimizer_registry.cc:114] Plugin optimizer for device_type GPU is enabled.
2023-11-20 13:38:59.210813: I tensorflow/core/grappler/optimizers/custom_graph_optimizer_registry.cc:114] Plugin optimizer for device_type GPU is enabled.


model 0 accuracy: 0.155


2023-11-20 13:39:02.836515: I tensorflow/core/grappler/optimizers/custom_graph_optimizer_registry.cc:114] Plugin optimizer for device_type GPU is enabled.


model 1 accuracy: 0.1523


2023-11-20 13:39:06.370841: I tensorflow/core/grappler/optimizers/custom_graph_optimizer_registry.cc:114] Plugin optimizer for device_type GPU is enabled.


model 2 accuracy: 0.1648


2023-11-20 13:39:10.054815: I tensorflow/core/grappler/optimizers/custom_graph_optimizer_registry.cc:114] Plugin optimizer for device_type GPU is enabled.


model 3 accuracy: 0.1608


2023-11-20 13:39:13.632675: I tensorflow/core/grappler/optimizers/custom_graph_optimizer_registry.cc:114] Plugin optimizer for device_type GPU is enabled.


model 4 accuracy: 0.1141


2023-11-20 13:39:17.176831: I tensorflow/core/grappler/optimizers/custom_graph_optimizer_registry.cc:114] Plugin optimizer for device_type GPU is enabled.


model 5 accuracy: 0.1709


2023-11-20 13:39:20.898933: I tensorflow/core/grappler/optimizers/custom_graph_optimizer_registry.cc:114] Plugin optimizer for device_type GPU is enabled.


model 6 accuracy: 0.1649


2023-11-20 13:39:24.512003: I tensorflow/core/grappler/optimizers/custom_graph_optimizer_registry.cc:114] Plugin optimizer for device_type GPU is enabled.


model 7 accuracy: 0.1517


2023-11-20 13:39:28.045229: I tensorflow/core/grappler/optimizers/custom_graph_optimizer_registry.cc:114] Plugin optimizer for device_type GPU is enabled.


model 8 accuracy: 0.1585


2023-11-20 13:39:31.560927: I tensorflow/core/grappler/optimizers/custom_graph_optimizer_registry.cc:114] Plugin optimizer for device_type GPU is enabled.


model 9 accuracy: 0.1271
round  1
model 0 accuracy: 0.1863
model 1 accuracy: 0.1818
model 2 accuracy: 0.2043
model 3 accuracy: 0.1957
model 4 accuracy: 0.1432
model 5 accuracy: 0.1966
model 6 accuracy: 0.1988
model 7 accuracy: 0.1841
model 8 accuracy: 0.1818
model 9 accuracy: 0.1216
round  2
model 0 accuracy: 0.1956
model 1 accuracy: 0.1904
model 2 accuracy: 0.2389
model 3 accuracy: 0.2232
model 4 accuracy: 0.1665
model 5 accuracy: 0.2077
model 6 accuracy: 0.2104
model 7 accuracy: 0.196
model 8 accuracy: 0.1813
model 9 accuracy: 0.107
round  3
model 0 accuracy: 0.1901
model 1 accuracy: 0.195
model 2 accuracy: 0.2269
model 3 accuracy: 0.2219
model 4 accuracy: 0.1731
model 5 accuracy: 0.2162
model 6 accuracy: 0.2074
model 7 accuracy: 0.1954
model 8 accuracy: 0.1809
model 9 accuracy: 0.1108
round  4
model 0 accuracy: 0.1807
model 1 accuracy: 0.1914
model 2 accuracy: 0.1953
model 3 accuracy: 0.1952
model 4 accuracy: 0.1728
model 5 accuracy: 0.2145
model 6 accuracy: 0.1957
model 7 accuracy:

In [None]:
# Persist the value returned by collaborative_training().
with open(os.path.join(result_save_dir, 'col_performance.pkl'), 'wb') as f:
    pickle.dump(collaboration_performance, f, protocol=pickle.HIGHEST_PROTOCOL)


# exist_ok=True keeps this cell re-runnable: without it a second execution
# fails with FileExistsError because the directory was created the first time.
models_save_dir = os.path.join(result_save_dir, 'models')
os.makedirs(models_save_dir, exist_ok=True)

# reduction='none' makes the loss object return unreduced (per-sample) losses.
loss_fnn = tf.keras.losses.SparseCategoricalCrossentropy(reduction = 'none')
for i, d in enumerate(alg.collaborative_parties):
    model = d['model_classifier']
    # model_stats is not defined in this notebook (presumably it comes from
    # `utility`); it is used here as returning a (predictions, losses) pair.
    train_preds, train_losses = model_stats(model, alg.tf_private_data[i], loss_fnn)
    test_preds, test_losses = model_stats(model, alg.tf_private_test_data, loss_fnn)

    # Format only the file name, not the joined path, so braces in the
    # directory name can never be misread as format fields.
    model.save(os.path.join(models_save_dir, 'model_{}.h5'.format(i)))
    np.save(os.path.join(models_save_dir, 'train_preds_{}.npy'.format(i)), train_preds)
    np.save(os.path.join(models_save_dir, 'train_losses_{}.npy'.format(i)), train_losses)
    np.save(os.path.join(models_save_dir, 'test_preds_{}.npy'.format(i)), test_preds)
    np.save(os.path.join(models_save_dir, 'test_losses_{}.npy'.format(i)), test_losses)
