In [1]:
import os
import time

import pandas as pd
import numpy as np
import tensorflow as tf
import tensorflow_federated as tff
import matplotlib.pyplot as plt
%matplotlib inline

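# Jupyter already runs an asyncio event loop; nest_asyncio patches it so nested loops can run inside the notebook.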
import nest_asyncio
nest_asyncio.apply()


from learning_util import load_dataset, preprocess, create_fl_dataset, create_model, save_ckpt, load_ckpt

In [2]:
# Placeholders for the parsed and unparsed command-line arguments; both are
# filled in by the configuration cell that follows.
FLAGS = None
_ = None

# Wall-clock reference point; the training loop reports elapsed seconds
# relative to this value.
STIME = time.time()

In [3]:
import argparse

# The parser declares no options of its own: parse_known_args() yields an
# empty Namespace plus whatever argv the host process received (in Jupyter,
# the kernel's "-f <connection file>" pair ends up in `_`).
parser = argparse.ArgumentParser()
FLAGS, _ = parser.parse_known_args()

# Run configuration, written straight into the Namespace in one step.
vars(FLAGS).update(
    debug=True,             # enables the diagnostic print() calls
    img_width=4,            # with img_height, defines the image shape built later
    img_height=375,
    input='ieeeaccess',     # argument handed to load_dataset
    nclients=10,            # number of federated clients
    val_size=0.1,           # passed to create_fl_dataset
    batch_size=32,
    shuffle_buffer=1024,
    num_epochs=1,           # repeat count applied to each client's dataset
    max_rounds=5,           # last federated round to run
    output='output',        # root directory for checkpoints
    ckpt_load=False,        # resume from an existing checkpoint when True
    ckpt_term=10,           # checkpoint every `ckpt_term` rounds
)

DEBUG = FLAGS.debug

if DEBUG:
    print(f'Parsed arguments {FLAGS}')
    print(f'Unparsed arguments {_}')

Parsed arguments Namespace(batch_size=32, ckpt_load=False, ckpt_term=10, debug=True, img_height=375, img_width=4, input='ieeeaccess', max_rounds=5, nclients=10, num_epochs=1, output='output', shuffle_buffer=1024, val_size=0.1)
Unparsed arguments ['-f', '/home/harny/.local/share/jupyter/runtime/kernel-958b26af-965a-40ce-99b6-21451f11628b.json']


In [4]:
# Load the raw data through the project helper. It returns three values:
#   dataset - a pandas DataFrame (the debug print shows columns idx / lab / vector)
#   idx2lab - a sized collection; its length is used later as the number of classes
#   idx2cnt - not used by the cells shown here (presumably per-class counts - TODO confirm)
dataset, idx2lab, idx2cnt = load_dataset(FLAGS.input)
if DEBUG:
    print(dataset)

         idx           lab                                             vector
0       [14]  [voipbuster]  [[69, 0, 0, 111, 112, 133, 64, 0, 128, 17, 161...
1        [9]        [sftp]  [[69, 0, 5, 220, 26, 42, 64, 0, 128, 6, 242, 1...
2        [3]        [ftps]  [[69, 0, 5, 220, 115, 38, 64, 0, 128, 6, 153, ...
3        [5]     [hangout]  [[69, 0, 3, 229, 121, 214, 0, 0, 128, 17, 0, 0...
4       [14]  [voipbuster]  [[69, 0, 0, 111, 5, 100, 64, 0, 128, 17, 12, 5...
...      ...           ...                                                ...
163323   [6]     [icqchat]  [[69, 0, 0, 78, 24, 198, 0, 0, 128, 17, 53, 23...
163324   [6]     [icqchat]  [[69, 0, 0, 50, 83, 46, 0, 0, 1, 17, 17, 111, ...
163325   [6]     [icqchat]  [[69, 0, 0, 114, 109, 12, 64, 0, 128, 6, 113, ...
163326   [6]     [icqchat]  [[69, 0, 0, 114, 45, 244, 64, 0, 128, 6, 176, ...
163327   [6]     [icqchat]  [[69, 0, 0, 50, 83, 48, 0, 0, 1, 17, 17, 109, ...

[163328 rows x 3 columns]


In [5]:
# Target image shape: (height, width, channels); the single channel means gray scale.
img_shape = (FLAGS.img_height, FLAGS.img_width, 1)

# preprocess(img_shape) returns the callable that is applied to every entry of
# the 'vector' column; the results become the model input column 'x'.
_to_image = preprocess(img_shape)
dataset['x'] = dataset['vector'].apply(_to_image)
# The label column 'y' is the class index column as-is.
dataset['y'] = dataset['idx']

# Shuffle every row (frac=1) and renumber the index from zero.
# NOTE(review): no random_state is given, so the row order differs on every run.
dataset = (
    dataset
    .sample(frac=1)
    .reset_index(drop=True)
)

if DEBUG:
    print(dataset)

         idx        lab                                             vector  \
0       [15]  [youtube]  [[69, 0, 10, 168, 213, 50, 0, 0, 58, 6, 76, 60...   
1        [7]  [netflix]  [[69, 0, 10, 168, 156, 47, 64, 0, 52, 6, 147, ...   
2        [4]    [gmail]  [[69, 0, 0, 65, 103, 68, 0, 0, 128, 17, 254, 1...   
3       [15]  [youtube]  [[69, 0, 10, 168, 111, 0, 0, 0, 58, 6, 178, 11...   
4       [15]  [youtube]  [[69, 0, 15, 226, 143, 225, 0, 0, 58, 6, 140, ...   
...      ...        ...                                                ...   
163323  [12]  [torrent]  [[69, 0, 5, 110, 16, 58, 64, 0, 49, 6, 11, 114...   
163324   [4]    [gmail]  [[69, 0, 5, 110, 110, 228, 64, 0, 128, 6, 229,...   
163325  [13]    [vimeo]  [[69, 0, 0, 235, 244, 242, 64, 0, 49, 6, 153, ...   
163326   [4]    [gmail]  [[69, 0, 3, 255, 101, 7, 0, 0, 128, 17, 48, 16...   
163327  [10]    [skype]  [[69, 0, 0, 55, 126, 123, 0, 0, 64, 17, 123, 2...   

                                                        x     y

In [6]:
# Split the shuffled frame into per-client training sets plus one held-out
# test set. raw_train_datasets is indexed by client position and measured
# with len() in later cells; raw_test_dataset is a single dataset.
# NOTE(review): FLAGS.val_size is presumably the held-out fraction - confirm in learning_util.
raw_train_datasets, raw_test_dataset = create_fl_dataset(
    dataset, idx2lab, FLAGS.nclients, FLAGS.val_size, printable=DEBUG)

split train dataset 14691
split train dataset 14691
split train dataset 14691
split train dataset 14691
split train dataset 14691
split train dataset 14691
split train dataset 14691
split train dataset 14691
split train dataset 14691
split train dataset 14691
split test dataset 16402


In [7]:
def client_fn(client_id):
    """Build the training input pipeline for one client.

    Args:
        client_id: index into ``raw_train_datasets`` selecting the client.

    Returns:
        The client's dataset repeated ``FLAGS.num_epochs`` times, then
        shuffled with a buffer of ``FLAGS.shuffle_buffer`` elements, then
        grouped into batches of ``FLAGS.batch_size``.
    """
    local_data = raw_train_datasets[client_id]
    repeated = local_data.repeat(FLAGS.num_epochs)
    # Shuffling happens after repeat, so elements from different epochs can mix.
    shuffled = repeated.shuffle(FLAGS.shuffle_buffer)
    return shuffled.batch(FLAGS.batch_size)

# Register every client index with TFF, then materialise one batched dataset
# per client. The resulting plain list is what each training round consumes.
_n_clients = len(raw_train_datasets)
_client_source = tff.simulation.ClientData.from_clients_and_fn(
    client_ids=range(_n_clients),
    create_tf_dataset_for_client_fn=client_fn)
client_data = list(map(_client_source.create_tf_dataset_for_client, range(_n_clients)))

# Held-out data goes through the same shuffle + batch steps, without repeat.
test_dataset = (
    raw_test_dataset
    .shuffle(FLAGS.shuffle_buffer)
    .batch(FLAGS.batch_size)
)

In [8]:
# Take the first batch of the test pipeline and convert every tensor in its
# (possibly nested) structure to a NumPy array.
_first_test_batch = next(iter(test_dataset))
sample_batch = tf.nest.map_structure(lambda tensor: tensor.numpy(), _first_test_batch)

In [9]:
def model_fn():
    """Create a fresh TFF model wrapping the project's Keras model.

    The Keras model, the loss and the metric are all constructed inside this
    function on every call, rather than being captured from an outer scope.

    Returns:
        The result of ``tff.learning.from_keras_model`` for a model with
        ``len(idx2lab)`` classes and input shape ``img_shape``; the input
        spec is taken from the batched test dataset.
    """
    keras_model = create_model(len(idx2lab), img_shape)
    batch_spec = test_dataset.element_spec
    # from_logits=True: assumes create_model's last layer emits raw logits - TODO confirm.
    loss_fn = tf.keras.losses.SparseCategoricalCrossentropy(from_logits=True)
    metric_list = [tf.keras.metrics.SparseCategoricalAccuracy()]
    return tff.learning.from_keras_model(
        keras_model,
        input_spec=batch_spec,
        loss=loss_fn,
        metrics=metric_list)

In [10]:
def _client_optimizer():
    """Return a new SGD optimizer (learning rate 0.02) for client-side training."""
    return tf.keras.optimizers.SGD(learning_rate=0.02)


def _server_optimizer():
    """Return a new SGD optimizer (learning rate 1.0) for the server update."""
    return tf.keras.optimizers.SGD(learning_rate=1.0)


# Federated Averaging process; both optimizers are supplied as zero-argument
# factories so that a fresh optimizer instance is built whenever one is needed.
iterative_process = tff.learning.build_federated_averaging_process(
    model_fn,
    client_optimizer_fn=_client_optimizer,
    server_optimizer_fn=_server_optimizer)

Instructions for updating:
If using Keras pass *_constraint arguments to layers.


In [11]:
# Initial server state for the federated process. It is replaced by a restored
# state when FLAGS.ckpt_load is set, and advanced once per round by the training loop.
state = iterative_process.initialize()

In [12]:
# Federated evaluation computation built from the same model_fn; the training
# loop calls it as evaluation(state.model, [test_dataset]) after every round.
evaluation = tff.learning.build_federated_evaluation(model_fn)

In [13]:
# Training history: 'rounds' holds the last completed round, and each of the
# other keys holds a list that grows by one value per round.
metrics = {
    'rounds': 0,
    **{series: [] for series in ('loss', 'accuracy', 'val_loss', 'val_accuracy')},
}

# Checkpoints live in a directory named after the run configuration
# (client count, epochs per round, maximum rounds).
_run_tag = f'c{FLAGS.nclients}_e{FLAGS.num_epochs}_r{FLAGS.max_rounds}'
path_output = os.path.join(FLAGS.output, _run_tag)

# Optionally resume: load_ckpt returns the restored server state and history.
if FLAGS.ckpt_load:
    state, metrics = load_ckpt(path_output, state)
    if DEBUG:
        print(f'Load completed: rounds {metrics["rounds"]}')

In [14]:
# Federated training loop.
#
# Starts at the round after metrics['rounds'] (so a restored run continues
# where it stopped) and runs through FLAGS.max_rounds inclusive. Each round:
#   1. runs one round of the federated process over every client dataset,
#   2. evaluates the updated model on the held-out test set, passed as a
#      one-element list of datasets,
#   3. appends the train / validation loss and accuracy to `metrics`,
#   4. writes a checkpoint every FLAGS.ckpt_term rounds.
# A final checkpoint is written after the loop unless the last round has
# already been saved, which avoids writing the same state twice in a row.
_last_saved_round = None
for rounds in range(metrics['rounds']+1, FLAGS.max_rounds+1):
    state, output = iterative_process.next(state, client_data)
    val_output = evaluation(state.model, [test_dataset])
    metrics['rounds'] = rounds
    metrics['loss'].append(output['train']['loss'])
    metrics['accuracy'].append(output['train']['sparse_categorical_accuracy'])
    metrics['val_loss'].append(val_output['loss'])
    metrics['val_accuracy'].append(val_output['sparse_categorical_accuracy'])
    if DEBUG:
        print((f'[{int(time.time()-STIME)}] rounds: {rounds}, '
               f'output: {output["train"]}, '
               f'val_output: {val_output}'))
    # A ckpt_term of 0 disables periodic checkpoints instead of raising
    # ZeroDivisionError in the modulo below.
    if FLAGS.ckpt_term and rounds%FLAGS.ckpt_term == 0:
        save_ckpt(path_output, state, metrics, create_model, len(idx2lab), img_shape)
        _last_saved_round = rounds
# Final save. It is skipped only when the last executed round was just saved;
# it still runs when the loop body never executed.
if _last_saved_round != metrics['rounds']:
    save_ckpt(path_output, state, metrics, create_model, len(idx2lab), img_shape)

[513] rounds: 1, output: OrderedDict([('sparse_categorical_accuracy', 0.52367437), ('loss', 1.4520663)]), val_output: OrderedDict([('sparse_categorical_accuracy', 0.76880866), ('loss', 0.96667546)])
[637] rounds: 2, output: OrderedDict([('sparse_categorical_accuracy', 0.7692533), ('loss', 0.8171528)]), val_output: OrderedDict([('sparse_categorical_accuracy', 0.82806975), ('loss', 0.615555)])
[769] rounds: 3, output: OrderedDict([('sparse_categorical_accuracy', 0.8396978), ('loss', 0.581441)]), val_output: OrderedDict([('sparse_categorical_accuracy', 0.846665), ('loss', 0.561372)])
[906] rounds: 4, output: OrderedDict([('sparse_categorical_accuracy', 0.86943024), ('loss', 0.476544)]), val_output: OrderedDict([('sparse_categorical_accuracy', 0.8886721), ('loss', 0.43721974)])
[1044] rounds: 5, output: OrderedDict([('sparse_categorical_accuracy', 0.88738), ('loss', 0.41798744)]), val_output: OrderedDict([('sparse_categorical_accuracy', 0.87733203), ('loss', 0.44450983)])


NameError: name 'tff' is not defined