In [1]:
import os
from os.path import join
import errno
import argparse
import sys
import pickle
import json 

import numpy as np
from tensorflow.keras.models import load_model
import tensorflow as tf

from data_utils import load_MNIST_data, load_FEMNIST_data, load_EMNIST_data, generate_bal_private_data
from data_utils import generate_partial_data, load_ready_data
from FedMD import FedMD, FedAvg
from Neural_Networks import train_models, cnn_2layer_fc_model, cnn_3layer_fc_model
from utility import * 

import pandas as pd            # For data manipulation
import seaborn as sns          # For plotting heatmap
import matplotlib.pyplot as plt  # For visualization and saving the plot
import logging



## Load config file

In [2]:
private_dataset_name = 'CIFAR10' # 'CIFAR10', 'CIFAR100', 'FEMNIST', 'MNIST

if private_dataset_name in ["CIFAR10", "CIFAR100"]:
    public_dataset_name = 'CIFAR10' if private_dataset_name == 'CIFAR100' else 'CIFAR100'
else : 
    public_dataset_name = 'MNIST' if private_dataset_name == 'FEMNIST' else 'FEMNIST'

print("private dataset: {0}".format(private_dataset_name))
print("public dataset: {0}".format(public_dataset_name))

CANDIDATE_MODELS = {"2_layer_CNN": cnn_2layer_fc_model, 
                    "3_layer_CNN": cnn_3layer_fc_model} 


if private_dataset_name in ["CIFAR10", "CIFAR100"]:
    conf_file = os.path.abspath("../conf/CIFAR_balance_conf.json")
else : 
    conf_file = os.path.abspath("../conf/MNIST_balance_conf.json")
with open(conf_file, "r") as f:
    conf_dict = json.load(f) 
    
    #n_classes = conf_dict["n_classes"]
    model_config = conf_dict["models"]
    pre_train_params = conf_dict["pre_train_params"]
    model_saved_dir = conf_dict["model_saved_dir"]
    model_saved_names = conf_dict["model_saved_names"]
    is_early_stopping = conf_dict["early_stopping"]
    public_classes = conf_dict["public_classes"]
    private_classes = conf_dict["private_classes"]
    n_classes = len(public_classes)
    
    
    N_parties = conf_dict["N_parties"]
    N_samples_per_class = conf_dict["N_samples_per_class"]
    
    N_rounds = conf_dict["N_rounds"]
    N_alignment = conf_dict["N_alignment"]
    N_private_training_round = conf_dict["N_private_training_round"]
    private_training_batchsize = conf_dict["private_training_batchsize"]
    N_logits_matching_round = conf_dict["N_logits_matching_round"]
    logits_matching_batchsize = conf_dict["logits_matching_batchsize"]
    aug = conf_dict["aug"]
    compress = conf_dict["compress"]
    select = conf_dict["select"]
    algorithm = conf_dict["algorithm"]
    
    dataset_dir = conf_dict["dataset_dir"]
    result_save_dir = conf_dict["result_save_dir"]
    
    if algorithm == 'fedavg':
        result_save_dir = result_save_dir + "_fedavg"
    
    elif algorithm == 'fedmd':
        result_save_dir = result_save_dir + "_fedmd"

        if aug : 
            print("adding aug")
            result_save_dir = result_save_dir + "_aug"
        if compress:
            print("adding compress")
            result_save_dir = result_save_dir + "_compress"
        if select:
            print("adding select")
            result_save_dir = result_save_dir + "_select"
        print("Using {} alignment".format(N_alignment))
        result_save_dir = result_save_dir + "_exp{}".format(N_alignment)

    if os.path.exists(result_save_dir):
        result_save_dir = result_save_dir + "_{}".format(np.random.randint(1000))
    os.makedirs(result_save_dir)


del conf_dict, conf_file


private dataset: CIFAR10
public dataset: CIFAR100


## Function

In [3]:

from PIL import Image

def all_digit(x) : 
    return all([c.isdigit() for c in x])

# resize image to shape 
def resize_this_image(x, shape, denormalize = True, normalize_back = True) : 
    if denormalize : 
        x = (x+0.5) * 255.0
        x = x.astype(np.uint8)
    y = np.array(Image.fromarray(x).resize(shape), dtype = np.float32) 
    if normalize_back : 
        y = y / 255.0 - 0.5
    return y


def resize_dataset(x, new_shape) : 
    num_images = len(x) 
    new_x = []
    for image in range(num_images) : 
        new_x.append(resize_this_image(x[image, ...], new_shape))
    new_x = np.array(new_x)
    return new_x


## Prepare dataset

In [4]:




dataset = 'CIFAR10' 
data_dir = "../data"
clients_data, alignment_data, test_data = load_ready_data(data_dir, dataset)


n_public_classes = len(alignment_data)
n_private_classes = len(test_data)  

In [5]:
algorithm = 'fedmd'
input_shape = (32, 32, 3) 
parties = [] 

for i in range(N_parties) : 
    item = model_config[i] 
    model_name = item['model_type']
    model_params = item['params']
    model = CANDIDATE_MODELS[model_name](n_classes = n_private_classes,
                                         input_shape = input_shape,
                                         **model_params)
    parties.append(model) 


len(parties) 

Metal device set to: Apple M1

systemMemory: 8.00 GB
maxCacheSize: 2.67 GB



2023-10-19 18:00:52.306537: I tensorflow/core/common_runtime/pluggable_device/pluggable_device_factory.cc:306] Could not identify NUMA node of platform GPU ID 0, defaulting to 0. Your kernel may not have been built with NUMA support.
2023-10-19 18:00:52.306751: I tensorflow/core/common_runtime/pluggable_device/pluggable_device_factory.cc:272] Created TensorFlow device (/job:localhost/replica:0/task:0/device:GPU:0 with 0 MB memory) -> physical PluggableDevice (device: 0, name: METAL, pci bus id: <undefined>)
  super().__init__(name, **kwargs)


10

In [6]:
algorithms = {'fedavg': FedAvg, 'fedmd': FedMD}


N_rounds = 1

if algorithm == 'fedavg':
        alg = algorithms[algorithm](parties, clients_data, test_data, N_rounds = N_rounds,
                                    N_private_training_round = N_private_training_round,
                                    private_training_batchsize = private_training_batchsize)
elif algorithm == 'fedmd':
    alg = algorithms[algorithm](parties, 
                original_public_dataset = alignment_data,
                private_data = clients_data, 
                private_test_data = test_data,
                N_rounds = N_rounds,N_alignment = N_alignment,
                N_logits_matching_round = N_logits_matching_round,
                logits_matching_batchsize = logits_matching_batchsize, 
                 N_private_training_round = N_private_training_round, 
                private_training_batchsize = private_training_batchsize,
                aug = aug, compress = compress, select = select)

collaboration_performance = alg.collaborative_training()




model  0
model  1
model  2
model  3
model  4
model  5
model  6
model  7
model  8
model  9
round  0
generated 3 alignment data in 1620708000000000.0 seconds:
alignment_data shape x:(100, 32, 32, 3)  Y:(100,)
update logits ... 


2023-10-19 18:00:54.824803: W tensorflow/core/platform/profile_utils/cpu_utils.cc:128] Failed to get CPU frequency: 0 Hz
2023-10-19 18:00:54.918942: I tensorflow/core/grappler/optimizers/custom_graph_optimizer_registry.cc:114] Plugin optimizer for device_type GPU is enabled.
2023-10-19 18:00:59.862428: I tensorflow/core/grappler/optimizers/custom_graph_optimizer_registry.cc:114] Plugin optimizer for device_type GPU is enabled.
2023-10-19 18:01:04.904986: I tensorflow/core/grappler/optimizers/custom_graph_optimizer_registry.cc:114] Plugin optimizer for device_type GPU is enabled.
2023-10-19 18:01:10.285511: I tensorflow/core/grappler/optimizers/custom_graph_optimizer_registry.cc:114] Plugin optimizer for device_type GPU is enabled.
2023-10-19 18:01:15.048412: I tensorflow/core/grappler/optimizers/custom_graph_optimizer_registry.cc:114] Plugin optimizer for device_type GPU is enabled.
2023-10-19 18:01:20.396659: I tensorflow/core/grappler/optimizers/custom_graph_optimizer_registry.cc:114

local logits shape:  (10, 100, 1, 10)
global logits shape:  (100, 1, 10)
diff_l:  (10, 100, 1, 10)
mean_diff_l:  (100, 1, 10)
i:0  c:0
i:0  c:1
i:0  c:2
i:0  c:3
i:0  c:4
i:0  c:5
i:0  c:6
i:0  c:7
i:0  c:8
i:0  c:9
i:1  c:0
i:1  c:1
i:1  c:2
i:1  c:3
i:1  c:4
i:1  c:5
i:1  c:6
i:1  c:7
i:1  c:8
i:1  c:9
i:2  c:0
i:2  c:1
i:2  c:2
i:2  c:3
i:2  c:4
i:2  c:5
i:2  c:6
i:2  c:7
i:2  c:8
i:2  c:9
i:3  c:0
i:3  c:1
i:3  c:2
i:3  c:3
i:3  c:4
i:3  c:5
i:3  c:6
i:3  c:7
i:3  c:8
i:3  c:9
i:4  c:0
i:4  c:1
i:4  c:2
i:4  c:3
i:4  c:4
i:4  c:5
i:4  c:6
i:4  c:7
i:4  c:8
i:4  c:9
i:5  c:0
i:5  c:1
i:5  c:2
i:5  c:3
i:5  c:4
i:5  c:5
i:5  c:6
i:5  c:7
i:5  c:8
i:5  c:9
i:6  c:0
i:6  c:1
i:6  c:2
i:6  c:3
i:6  c:4
i:6  c:5
i:6  c:6
i:6  c:7
i:6  c:8
i:6  c:9
i:7  c:0
i:7  c:1
i:7  c:2
i:7  c:3
i:7  c:4
i:7  c:5
i:7  c:6
i:7  c:7
i:7  c:8
i:7  c:9
i:8  c:0
i:8  c:1
i:8  c:2
i:8  c:3
i:8  c:4
i:8  c:5
i:8  c:6
i:8  c:7
i:8  c:8
i:8  c:9
i:9  c:0
i:9  c:1
i:9  c:2
i:9  c:3
i:9  c:4
i:9  c:5
i:9  c:6
i

2023-10-19 18:01:45.670642: I tensorflow/core/grappler/optimizers/custom_graph_optimizer_registry.cc:114] Plugin optimizer for device_type GPU is enabled.


model 0 done alignment
size of private_data: (50, 32, 32, 3) (50, 10)
dtype: float64 float32
Model: "model"
_________________________________________________________________
 Layer (type)                Output Shape              Param #   
 input_1 (InputLayer)        [(None, 32, 32, 3)]       0         
                                                                 
 reshape (Reshape)           (None, 32, 32, 3)         0         
                                                                 
 conv2d (Conv2D)             (None, 32, 32, 128)       3584      
                                                                 
 batch_normalization (BatchN  (None, 32, 32, 128)      512       
 ormalization)                                                   
                                                                 
 activation (Activation)     (None, 32, 32, 128)       0         
                                                                 
 dropout (Dropout)           (None

2023-10-19 18:01:46.925347: I tensorflow/core/grappler/optimizers/custom_graph_optimizer_registry.cc:114] Plugin optimizer for device_type GPU is enabled.


InvalidArgumentError: Graph execution error:

Detected at node 'sparse_categorical_crossentropy/SparseSoftmaxCrossEntropyWithLogits/SparseSoftmaxCrossEntropyWithLogits' defined at (most recent call last):
    File "/Users/gadmohamed/miniforge3/envs/fl/lib/python3.9/runpy.py", line 197, in _run_module_as_main
      return _run_code(code, main_globals, None,
    File "/Users/gadmohamed/miniforge3/envs/fl/lib/python3.9/runpy.py", line 87, in _run_code
      exec(code, run_globals)
    File "/Users/gadmohamed/miniforge3/envs/fl/lib/python3.9/site-packages/ipykernel_launcher.py", line 17, in <module>
      app.launch_new_instance()
    File "/Users/gadmohamed/miniforge3/envs/fl/lib/python3.9/site-packages/traitlets/config/application.py", line 846, in launch_instance
      app.start()
    File "/Users/gadmohamed/miniforge3/envs/fl/lib/python3.9/site-packages/ipykernel/kernelapp.py", line 712, in start
      self.io_loop.start()
    File "/Users/gadmohamed/miniforge3/envs/fl/lib/python3.9/site-packages/tornado/platform/asyncio.py", line 215, in start
      self.asyncio_loop.run_forever()
    File "/Users/gadmohamed/miniforge3/envs/fl/lib/python3.9/asyncio/base_events.py", line 601, in run_forever
      self._run_once()
    File "/Users/gadmohamed/miniforge3/envs/fl/lib/python3.9/asyncio/base_events.py", line 1905, in _run_once
      handle._run()
    File "/Users/gadmohamed/miniforge3/envs/fl/lib/python3.9/asyncio/events.py", line 80, in _run
      self._context.run(self._callback, *self._args)
    File "/Users/gadmohamed/miniforge3/envs/fl/lib/python3.9/site-packages/ipykernel/kernelbase.py", line 510, in dispatch_queue
      await self.process_one()
    File "/Users/gadmohamed/miniforge3/envs/fl/lib/python3.9/site-packages/ipykernel/kernelbase.py", line 499, in process_one
      await dispatch(*args)
    File "/Users/gadmohamed/miniforge3/envs/fl/lib/python3.9/site-packages/ipykernel/kernelbase.py", line 406, in dispatch_shell
      await result
    File "/Users/gadmohamed/miniforge3/envs/fl/lib/python3.9/site-packages/ipykernel/kernelbase.py", line 730, in execute_request
      reply_content = await reply_content
    File "/Users/gadmohamed/miniforge3/envs/fl/lib/python3.9/site-packages/ipykernel/ipkernel.py", line 383, in do_execute
      res = shell.run_cell(
    File "/Users/gadmohamed/miniforge3/envs/fl/lib/python3.9/site-packages/ipykernel/zmqshell.py", line 528, in run_cell
      return super().run_cell(*args, **kwargs)
    File "/Users/gadmohamed/miniforge3/envs/fl/lib/python3.9/site-packages/IPython/core/interactiveshell.py", line 2881, in run_cell
      result = self._run_cell(
    File "/Users/gadmohamed/miniforge3/envs/fl/lib/python3.9/site-packages/IPython/core/interactiveshell.py", line 2936, in _run_cell
      return runner(coro)
    File "/Users/gadmohamed/miniforge3/envs/fl/lib/python3.9/site-packages/IPython/core/async_helpers.py", line 129, in _pseudo_sync_runner
      coro.send(None)
    File "/Users/gadmohamed/miniforge3/envs/fl/lib/python3.9/site-packages/IPython/core/interactiveshell.py", line 3135, in run_cell_async
      has_raised = await self.run_ast_nodes(code_ast.body, cell_name,
    File "/Users/gadmohamed/miniforge3/envs/fl/lib/python3.9/site-packages/IPython/core/interactiveshell.py", line 3338, in run_ast_nodes
      if await self.run_code(code, result, async_=asy):
    File "/Users/gadmohamed/miniforge3/envs/fl/lib/python3.9/site-packages/IPython/core/interactiveshell.py", line 3398, in run_code
      exec(code_obj, self.user_global_ns, self.user_ns)
    File "/var/folders/fy/1qq1n6j95tn7gy2bwgbsqlhc0000gn/T/ipykernel_49245/309904467.py", line 22, in <cell line: 22>
      collaboration_performance = alg.collaborative_training()
    File "/Users/gadmohamed/Desktop/live repos/FedSKD/src/FedMD.py", line 291, in collaborative_training
      d["model_classifier"].fit(self.private_data[index][0], self.private_data[index][1],
    File "/Users/gadmohamed/miniforge3/envs/fl/lib/python3.9/site-packages/keras/utils/traceback_utils.py", line 65, in error_handler
      return fn(*args, **kwargs)
    File "/Users/gadmohamed/miniforge3/envs/fl/lib/python3.9/site-packages/keras/engine/training.py", line 1564, in fit
      tmp_logs = self.train_function(iterator)
    File "/Users/gadmohamed/miniforge3/envs/fl/lib/python3.9/site-packages/keras/engine/training.py", line 1160, in train_function
      return step_function(self, iterator)
    File "/Users/gadmohamed/miniforge3/envs/fl/lib/python3.9/site-packages/keras/engine/training.py", line 1146, in step_function
      outputs = model.distribute_strategy.run(run_step, args=(data,))
    File "/Users/gadmohamed/miniforge3/envs/fl/lib/python3.9/site-packages/keras/engine/training.py", line 1135, in run_step
      outputs = model.train_step(data)
    File "/Users/gadmohamed/miniforge3/envs/fl/lib/python3.9/site-packages/keras/engine/training.py", line 994, in train_step
      loss = self.compute_loss(x, y, y_pred, sample_weight)
    File "/Users/gadmohamed/miniforge3/envs/fl/lib/python3.9/site-packages/keras/engine/training.py", line 1052, in compute_loss
      return self.compiled_loss(
    File "/Users/gadmohamed/miniforge3/envs/fl/lib/python3.9/site-packages/keras/engine/compile_utils.py", line 265, in __call__
      loss_value = loss_obj(y_t, y_p, sample_weight=sw)
    File "/Users/gadmohamed/miniforge3/envs/fl/lib/python3.9/site-packages/keras/losses.py", line 152, in __call__
      losses = call_fn(y_true, y_pred)
    File "/Users/gadmohamed/miniforge3/envs/fl/lib/python3.9/site-packages/keras/losses.py", line 272, in call
      return ag_fn(y_true, y_pred, **self._fn_kwargs)
    File "/Users/gadmohamed/miniforge3/envs/fl/lib/python3.9/site-packages/keras/losses.py", line 2084, in sparse_categorical_crossentropy
      return backend.sparse_categorical_crossentropy(
    File "/Users/gadmohamed/miniforge3/envs/fl/lib/python3.9/site-packages/keras/backend.py", line 5630, in sparse_categorical_crossentropy
      res = tf.nn.sparse_softmax_cross_entropy_with_logits(
Node: 'sparse_categorical_crossentropy/SparseSoftmaxCrossEntropyWithLogits/SparseSoftmaxCrossEntropyWithLogits'
Detected at node 'sparse_categorical_crossentropy/SparseSoftmaxCrossEntropyWithLogits/SparseSoftmaxCrossEntropyWithLogits' defined at (most recent call last):
    File "/Users/gadmohamed/miniforge3/envs/fl/lib/python3.9/runpy.py", line 197, in _run_module_as_main
      return _run_code(code, main_globals, None,
    File "/Users/gadmohamed/miniforge3/envs/fl/lib/python3.9/runpy.py", line 87, in _run_code
      exec(code, run_globals)
    File "/Users/gadmohamed/miniforge3/envs/fl/lib/python3.9/site-packages/ipykernel_launcher.py", line 17, in <module>
      app.launch_new_instance()
    File "/Users/gadmohamed/miniforge3/envs/fl/lib/python3.9/site-packages/traitlets/config/application.py", line 846, in launch_instance
      app.start()
    File "/Users/gadmohamed/miniforge3/envs/fl/lib/python3.9/site-packages/ipykernel/kernelapp.py", line 712, in start
      self.io_loop.start()
    File "/Users/gadmohamed/miniforge3/envs/fl/lib/python3.9/site-packages/tornado/platform/asyncio.py", line 215, in start
      self.asyncio_loop.run_forever()
    File "/Users/gadmohamed/miniforge3/envs/fl/lib/python3.9/asyncio/base_events.py", line 601, in run_forever
      self._run_once()
    File "/Users/gadmohamed/miniforge3/envs/fl/lib/python3.9/asyncio/base_events.py", line 1905, in _run_once
      handle._run()
    File "/Users/gadmohamed/miniforge3/envs/fl/lib/python3.9/asyncio/events.py", line 80, in _run
      self._context.run(self._callback, *self._args)
    File "/Users/gadmohamed/miniforge3/envs/fl/lib/python3.9/site-packages/ipykernel/kernelbase.py", line 510, in dispatch_queue
      await self.process_one()
    File "/Users/gadmohamed/miniforge3/envs/fl/lib/python3.9/site-packages/ipykernel/kernelbase.py", line 499, in process_one
      await dispatch(*args)
    File "/Users/gadmohamed/miniforge3/envs/fl/lib/python3.9/site-packages/ipykernel/kernelbase.py", line 406, in dispatch_shell
      await result
    File "/Users/gadmohamed/miniforge3/envs/fl/lib/python3.9/site-packages/ipykernel/kernelbase.py", line 730, in execute_request
      reply_content = await reply_content
    File "/Users/gadmohamed/miniforge3/envs/fl/lib/python3.9/site-packages/ipykernel/ipkernel.py", line 383, in do_execute
      res = shell.run_cell(
    File "/Users/gadmohamed/miniforge3/envs/fl/lib/python3.9/site-packages/ipykernel/zmqshell.py", line 528, in run_cell
      return super().run_cell(*args, **kwargs)
    File "/Users/gadmohamed/miniforge3/envs/fl/lib/python3.9/site-packages/IPython/core/interactiveshell.py", line 2881, in run_cell
      result = self._run_cell(
    File "/Users/gadmohamed/miniforge3/envs/fl/lib/python3.9/site-packages/IPython/core/interactiveshell.py", line 2936, in _run_cell
      return runner(coro)
    File "/Users/gadmohamed/miniforge3/envs/fl/lib/python3.9/site-packages/IPython/core/async_helpers.py", line 129, in _pseudo_sync_runner
      coro.send(None)
    File "/Users/gadmohamed/miniforge3/envs/fl/lib/python3.9/site-packages/IPython/core/interactiveshell.py", line 3135, in run_cell_async
      has_raised = await self.run_ast_nodes(code_ast.body, cell_name,
    File "/Users/gadmohamed/miniforge3/envs/fl/lib/python3.9/site-packages/IPython/core/interactiveshell.py", line 3338, in run_ast_nodes
      if await self.run_code(code, result, async_=asy):
    File "/Users/gadmohamed/miniforge3/envs/fl/lib/python3.9/site-packages/IPython/core/interactiveshell.py", line 3398, in run_code
      exec(code_obj, self.user_global_ns, self.user_ns)
    File "/var/folders/fy/1qq1n6j95tn7gy2bwgbsqlhc0000gn/T/ipykernel_49245/309904467.py", line 22, in <cell line: 22>
      collaboration_performance = alg.collaborative_training()
    File "/Users/gadmohamed/Desktop/live repos/FedSKD/src/FedMD.py", line 291, in collaborative_training
      d["model_classifier"].fit(self.private_data[index][0], self.private_data[index][1],
    File "/Users/gadmohamed/miniforge3/envs/fl/lib/python3.9/site-packages/keras/utils/traceback_utils.py", line 65, in error_handler
      return fn(*args, **kwargs)
    File "/Users/gadmohamed/miniforge3/envs/fl/lib/python3.9/site-packages/keras/engine/training.py", line 1564, in fit
      tmp_logs = self.train_function(iterator)
    File "/Users/gadmohamed/miniforge3/envs/fl/lib/python3.9/site-packages/keras/engine/training.py", line 1160, in train_function
      return step_function(self, iterator)
    File "/Users/gadmohamed/miniforge3/envs/fl/lib/python3.9/site-packages/keras/engine/training.py", line 1146, in step_function
      outputs = model.distribute_strategy.run(run_step, args=(data,))
    File "/Users/gadmohamed/miniforge3/envs/fl/lib/python3.9/site-packages/keras/engine/training.py", line 1135, in run_step
      outputs = model.train_step(data)
    File "/Users/gadmohamed/miniforge3/envs/fl/lib/python3.9/site-packages/keras/engine/training.py", line 994, in train_step
      loss = self.compute_loss(x, y, y_pred, sample_weight)
    File "/Users/gadmohamed/miniforge3/envs/fl/lib/python3.9/site-packages/keras/engine/training.py", line 1052, in compute_loss
      return self.compiled_loss(
    File "/Users/gadmohamed/miniforge3/envs/fl/lib/python3.9/site-packages/keras/engine/compile_utils.py", line 265, in __call__
      loss_value = loss_obj(y_t, y_p, sample_weight=sw)
    File "/Users/gadmohamed/miniforge3/envs/fl/lib/python3.9/site-packages/keras/losses.py", line 152, in __call__
      losses = call_fn(y_true, y_pred)
    File "/Users/gadmohamed/miniforge3/envs/fl/lib/python3.9/site-packages/keras/losses.py", line 272, in call
      return ag_fn(y_true, y_pred, **self._fn_kwargs)
    File "/Users/gadmohamed/miniforge3/envs/fl/lib/python3.9/site-packages/keras/losses.py", line 2084, in sparse_categorical_crossentropy
      return backend.sparse_categorical_crossentropy(
    File "/Users/gadmohamed/miniforge3/envs/fl/lib/python3.9/site-packages/keras/backend.py", line 5630, in sparse_categorical_crossentropy
      res = tf.nn.sparse_softmax_cross_entropy_with_logits(
Node: 'sparse_categorical_crossentropy/SparseSoftmaxCrossEntropyWithLogits/SparseSoftmaxCrossEntropyWithLogits'
2 root error(s) found.
  (0) INVALID_ARGUMENT:  logits and labels must have the same first dimension, got logits shape [50,10] and labels shape [500]
	 [[{{node sparse_categorical_crossentropy/SparseSoftmaxCrossEntropyWithLogits/SparseSoftmaxCrossEntropyWithLogits}}]]
	 [[Equal/_21]]
  (1) INVALID_ARGUMENT:  logits and labels must have the same first dimension, got logits shape [50,10] and labels shape [500]
	 [[{{node sparse_categorical_crossentropy/SparseSoftmaxCrossEntropyWithLogits/SparseSoftmaxCrossEntropyWithLogits}}]]
0 successful operations.
0 derived errors ignored. [Op:__inference_train_function_70870]

In [None]:
with open(os.path.join(result_save_dir, 'col_performance.pkl'), 'wb') as f:
    pickle.dump(collaboration_performance, f, protocol=pickle.HIGHEST_PROTOCOL)


models_save_dir = join(result_save_dir, 'models')
os.makedirs(models_save_dir)

loss_fnn = tf.keras.losses.SparseCategoricalCrossentropy(reduction = 'none')
for i, d in enumerate(alg.collaborative_parties):
    model = d['model_classifier']
    train_preds, train_losses = model_stats(model, alg.tf_private_data[i], loss_fnn)
    test_preds, test_losses = model_stats(model, alg.tf_private_test_data, loss_fnn)

    model.save(os.path.join(models_save_dir, 'model_{}.h5').format(i))
    np.save(os.path.join(models_save_dir, 'train_preds_{}.npy').format(i), train_preds)
    np.save(os.path.join(models_save_dir, 'train_losses_{}.npy').format(i), train_losses)
    np.save(os.path.join(models_save_dir, 'test_preds_{}.npy').format(i), test_preds)
    np.save(os.path.join(models_save_dir, 'test_losses_{}.npy').format(i), test_losses)
