In [None]:
import os
from operator import itemgetter
import collections

import matplotlib.pylab as plt
# %matplotlib widget
%matplotlib inline

import tensorflow as tf
tf.random.set_seed(99)
import tensorflow_federated as tff
import numpy as np
import nest_asyncio
nest_asyncio.apply()
tf.compat.v1.enable_v2_behavior() # https://www.tensorflow.org/api_docs/python/tf/compat/v1/enable_v2_behavior

print(f'Tensorflow version: {tf.__version__}')
print(f'Tensorflow Federated version: {tff.__version__}')

In [None]:
# Global configuration shared by the cells below.
IMG_DATA = input('Image data path: ')  # dataset location, asked for interactively
IMG_SHAPE = (375, 4)  # used as `target_size` when images are loaded
BATCH_SIZE = 32  # batch size of the image iterator; also used to turn sample counts into batch counts
# Class names handed to flow_from_directory via the `classes` option.
CLASSES = [
    'aim',
    'email',
    'facebook',
    'ftps',
    'gmail',
    'hangout',
    'icqchat',
    'netflix',
    'scp',
    'sftp',
    'skype',
    'spotify',
    'torrent',
    'vimeo',
    'voipbuster',
    'youtube',
]

In [None]:
%%time
# prepare dataset
# Expand "~" first, then make the user-supplied location absolute.
dataset_root = os.path.abspath(os.path.expanduser(IMG_DATA))
print(f'Dataset root: {dataset_root}')

# Keyword arguments forwarded unchanged to every flow_from_directory call.
img_gen_op = dict(
    classes=CLASSES,
    target_size=IMG_SHAPE,
    batch_size=BATCH_SIZE,
    color_mode='grayscale',
)
# Shared image loader configured with a 1/255 rescale factor.
image_generator = tf.keras.preprocessing.image.ImageDataGenerator(rescale=1/255)

def gen_fn(args):
    """Open a directory iterator for one client directory.

    Args:
        args: UTF-8 encoded bytes holding the directory path (this function is
            used as the generator of `tf.data.Dataset.from_generator`, which
            supplies its `args` in that form).

    Returns:
        The iterator produced by `image_generator.flow_from_directory` for the
        decoded path, configured with the shared `img_gen_op` options.
    """
    directory = args.decode('utf-8')
    batches = image_generator.flow_from_directory(directory, **img_gen_op)
    return batches

# Count files per client with a breadth-first walk of the dataset tree.
# Every file is attributed to the directory two levels above it, so the keys
# are client directory names when the layout is <root>/<client>/<class>/<file>
# (assumed layout -- TODO confirm against the actual dataset).
dataset_size = dict()
# deque gives O(1) pops from the front; list.pop(0) shifts the whole list each time.
queue = collections.deque([dataset_root])
while queue:
    path = queue.popleft()
    with os.scandir(path) as it:
        for entry in it:
            if entry.is_dir():
                # Visit sub-directories later, in discovery order.
                queue.append(entry.path)
            elif entry.is_file():
                name = os.path.basename(os.path.dirname(os.path.dirname(entry.path)))
                dataset_size[name] = dataset_size.get(name, 0) + 1

# One generator-backed tf.data.Dataset per directory found directly under the
# dataset root, keyed by the directory name.
dataset_dict = {}
with os.scandir(dataset_root) as client_dirs:
    for client_dir in client_dirs:
        if not client_dir.is_dir():
            continue
        dataset_dict[client_dir.name] = tf.data.Dataset.from_generator(
            gen_fn,
            args=[client_dir.path],
            output_types=(tf.float32, tf.float32),
            # Batches of IMG_SHAPE + (1,) images and len(CLASSES)-wide label rows;
            # the batch dimension is left unspecified.
            output_shapes=(
                tf.TensorShape((None, ) + IMG_SHAPE + (1, )),
                tf.TensorShape([None, len(CLASSES)]),
            ),
        )

In [None]:
def client_fn(client_id):
    """Return the dataset registered for `client_id` in `dataset_dict`.

    Raises:
        KeyError: if no dataset was registered under `client_id`.
    """
    client_dataset = dataset_dict[client_id]
    return client_dataset

# Expose the per-directory datasets through the TFF ClientData interface.
client_data = tff.simulation.ClientData.from_clients_and_fn(
    client_ids=list(dataset_dict),
    create_tf_dataset_for_client_fn=client_fn,
)

# Every client except '0' takes part in training.
train_ids = list(dataset_dict)
train_ids.remove('0')
# train_ids = ['1'] ## for experiment client each
# Pair each training dataset with the file count recorded for that client.
dataset = [
    (client_data.create_tf_dataset_for_client(client_id), dataset_size[client_id])
    for client_id in train_ids
]

In [None]:
# (dataset, file count) pair of the first listed client, used as a sample.
first_client_id = client_data.client_ids[0]
example_dataset = (
    client_data.create_tf_dataset_for_client(first_client_id),
    dataset_size[first_client_id],
)
print(example_dataset)

In [None]:
# Algorithm
import statistics
# Sample budget handed to preprocess(); None makes preprocess() fall back to
# each client's own file count. Alternative kept for experiments:
# take_value = statistics.median(dataset_size.values())
take_value = None

In [None]:
def preprocess(dataset, take_value=None):
    """Limit a client dataset to the batches needed to cover `take_value` samples.

    Args:
        dataset: `(tf_dataset, sample_count)` pair. Element 0 must provide
            `.take()`; element 1 is the number of samples recorded for the
            client.
        take_value: number of samples to draw. When None, the client's own
            count (`dataset[1]`) is used. The value is used as given and is
            not capped at `dataset[1]`.

    Returns:
        `dataset[0].take(n)` with `n = ceil(take_value / BATCH_SIZE)`.
    """
    if take_value is None:
        take_value = dataset[1]
    # Round up so a trailing partial batch is still included, and convert to a
    # plain int because take() expects an integer count, not a NumPy float.
    num_batches = int(np.ceil(take_value / BATCH_SIZE))
    return dataset[0].take(num_batches)
    
preprocessed_example_dataset = preprocess(example_dataset, take_value)
# Fetch the first batch with the built-in next(); the iterator's own .next()
# method is a Python 2 style API and is not part of the iterator protocol.
first_batch = next(iter(preprocessed_example_dataset))
# Convert every tensor of the batch structure to a NumPy array.
sample_batch = tf.nest.map_structure(lambda x: x.numpy(), first_batch)
print(sample_batch[0].shape, sample_batch[1].shape)

In [None]:
# Apply the same sample budget to every (dataset, size) training pair.
federated_dataset = [
    preprocess(client_pair, take_value)
    for client_pair in dataset
]

In [None]:
# Evaluation data: client '0', limited to its own file count.
test_client_id = '0'
test_dataset = (
    client_data.create_tf_dataset_for_client(test_client_id),
    dataset_size[test_client_id],
)
federated_test_data = [preprocess(test_dataset, None)]

In [None]:
def create_keras_model():
    """Build the multi-branch ("deep and wide") CNN classifier.

    Five feature branches are computed from the same input, flattened,
    concatenated and fed to one Dense layer with `len(CLASSES)` units. The
    Dense layer has no activation.

    Returns:
        A `tf.keras.Model` with a single input of shape `IMG_SHAPE + (1,)` and
        a single output of width `len(CLASSES)`.
    """
    layers = tf.keras.layers
    img_input = tf.keras.Input(shape=IMG_SHAPE+(1, ))

    # Three single-convolution branches that differ only in kernel size.
    # Layers are created in the order conv, flatten, conv, flatten, ...
    shallow_branches = []
    for kernel_size in ((1, 1), (1, 2), (1, 4)):
        conv_out = layers.Conv2D(32, kernel_size, activation='relu')(img_input)
        shallow_branches.append(layers.Flatten()(conv_out))

    # Stacked path: conv -> max-pool -> conv. Both the pooled map and the
    # second convolution's output are used as features.
    first_conv = layers.Conv2D(32, (2, 2), activation='relu')(img_input)
    pooled = layers.MaxPooling2D((2, 2), strides=(1, 1))(first_conv)
    second_conv = layers.Conv2D(32, (2, 2), activation='relu')(pooled)

    pooled_flat = layers.Flatten()(pooled)
    second_conv_flat = layers.Flatten()(second_conv)

    merged = layers.concatenate(shallow_branches + [pooled_flat, second_conv_flat])

    pred = layers.Dense(len(CLASSES))(merged)

    return tf.keras.Model(inputs=[img_input], outputs=[pred])

In [None]:
def model_fn():
    """Wrap a newly built Keras model as a TFF learning model.

    Each call builds a new model through `create_keras_model()` and new loss
    and metric objects. `sample_batch` is passed as the dummy batch.

    Returns:
        The object returned by `tff.learning.from_keras_model`.
    """
    fresh_model = create_keras_model()
    return tff.learning.from_keras_model(
        fresh_model,
        dummy_batch=sample_batch,
        # from_logits=True: the loss is told the model outputs are logits.
        loss=tf.keras.losses.CategoricalCrossentropy(from_logits=True),
        metrics=[tf.keras.metrics.CategoricalAccuracy()],
    )

In [None]:
# Federated averaging process: clients use SGD with learning rate 0.02, the
# server uses SGD with learning rate 1.0. Optimizers are passed as factories.
iterative_process = tff.learning.build_federated_averaging_process(
    model_fn,
    client_optimizer_fn=lambda: tf.keras.optimizers.SGD(learning_rate=0.02),
    server_optimizer_fn=lambda: tf.keras.optimizers.SGD(learning_rate=1.0),
)

In [None]:
# Show the type signature of the process's initialize computation as text.
str(iterative_process.initialize.type_signature)

In [None]:
# Federated evaluation computation built from model_fn; it is called with
# (state.model, federated_test_data) inside the training loop.
evaluation = tff.learning.build_federated_evaluation(model_fn)

In [None]:
# Initial state of the iterative process; the training loop replaces it each round.
state = iterative_process.initialize()

In [None]:
NUM_ROUNDS = 100  # upper bound on the number of federated rounds
MAX_STD = 0.001   # stop once the last three validation losses vary less than this
# Per-round histories, read by the plotting and pickling cells.
loss, accuracy, val_loss, val_accuracy = [], [], [], []
for round_num in range(1, NUM_ROUNDS+1):
    # One training round, followed by evaluation of the updated model.
    state, metrics = iterative_process.next(state, federated_dataset)
    val_metrics = evaluation(state.model, federated_test_data)

    round_values = (
        (loss, metrics.loss),
        (accuracy, metrics.categorical_accuracy),
        (val_loss, val_metrics.loss),
        (val_accuracy, val_metrics.categorical_accuracy),
    )
    for history, value in round_values:
        history.append(value)

    print(f'round: {round_num:2d}, loss: {metrics.loss}, test_accuracy: {val_metrics.categorical_accuracy}')

    # The stopping check only applies once more than three rounds are recorded.
    if len(val_loss) <= 3:
        continue
    if np.std(val_loss[-3:]) < MAX_STD:
        break

In [None]:
fig1 = plt.figure(figsize=(8, 8))

# Top panel: accuracy per round.
ax1 = fig1.add_subplot(2, 1, 1)
for series, label in ((accuracy, 'Training Accuracy'),
                      (val_accuracy, 'Validation Accuracy')):
    ax1.plot(series, label=label)
ax1.legend(loc='lower right')
ax1.set_ylabel('Accuracy')
ax1.set_ylim([0, 1])
ax1.set_title('Training and Validation Accuracy')

# Bottom panel: loss per round.
ax2 = fig1.add_subplot(2, 1, 2)
for series, label in ((loss, 'Training Loss'),
                      (val_loss, 'Validation Loss')):
    ax2.plot(series, label=label)
ax2.legend(loc='upper right')
ax2.set_ylabel('Cross Entropy')
# Pin the lower limit to 0 and keep the automatically chosen upper limit.
upper_limit = max(ax2.get_ylim())
ax2.set_ylim([0, upper_limit])
ax2.set_title('Training and Validation Loss')
ax2.set_xlabel('epoch')

In [None]:
import pickle

# Persist the four histories as one tuple, then read the file back and print
# the best validation accuracy from both the file and memory.
curves = (accuracy, val_accuracy, loss, val_loss)
with open('output.pickle', 'wb') as f:
    pickle.dump(curves, f)
with open('output.pickle', 'rb') as f:
    # This file was written just above; do not unpickle files from untrusted sources.
    restored_curves = pickle.load(f)
    print(max(restored_curves[1]))
print(max(val_accuracy))