In [None]:
# Setup library
# Install dependencies first: pip install -r requirements.txt
import os
import csv
import pickle
import time

import pandas as pd
import numpy as np
# np.random.seed(99)

# Jupyter: allow nested asyncio event loops (needed to run TFF computations inside a notebook)
import nest_asyncio
nest_asyncio.apply()

import tensorflow as tf
# tf.random.set_seed(99)
import tensorflow_federated as tff

import matplotlib.pyplot as plt
%matplotlib inline

In [None]:
### Global variables
# Each flat sample is reshaped to this (rows, cols) grid during preprocessing; the
# model later adds a trailing channel axis of size 1.
IMG_SHAPE = (375, 4)

# Fill these in, or supply them via the environment, before running the notebook.
path_pkl = os.environ.get('FL_DATASET_PKL', '')
path_index = os.environ.get('FL_INDEX_PKL', '')
if not path_pkl or not path_index:
    raise ValueError('Set path_pkl and path_index (or the FL_DATASET_PKL / FL_INDEX_PKL '
                     'environment variables) before running this notebook.')

# NOTE(security): unpickling can execute arbitrary code -- only load files you trust.
dataset = pd.read_pickle(path_pkl)

with open(path_index, 'rb') as f:
    # Pickled pair; idx2lab is a mapping whose size is used as the class count below.
    # lab2cnt presumably maps label -> sample count (TODO confirm).
    idx2lab, lab2cnt = pickle.load(f)

In [None]:
def pre_reshape(x, img_shape=None, scale=255):
    """Turn one flat sample into a scaled, single-channel 2-D array.

    Args:
        x: array-like holding prod(img_shape) values.
        img_shape: target (rows, cols). Defaults to the notebook-level IMG_SHAPE,
            looked up at call time.
        scale: divisor applied to every value. Defaults to 255, presumably for
            8-bit pixel data.

    Returns:
        np.ndarray of shape ``img_shape + (1,)``.
    """
    if img_shape is None:
        img_shape = IMG_SHAPE
    return np.expand_dims(np.asarray(x).reshape(img_shape) / scale, axis=-1)
# Normalise from an untouched copy of the raw samples: applying pre_reshape to the
# already-processed column would divide by 255 again every time this cell is re-run.
if 'x_raw' not in dataset.columns:
    dataset['x_raw'] = dataset['x']
dataset['x'] = dataset['x_raw'].apply(pre_reshape)

In [None]:
# Number of simulated federated clients that the training rows are split across.
NCLIENTS: int = 5

In [None]:
def _split_bounds(nrow, nclients, val_size):
    """Return half-open positional ranges: ([(start, stop)] per client, (val_start, val_stop)).

    The first ``int(nrow * (1 - val_size))`` rows are cut into ``nclients`` equal
    shards. Rows left over by the integer division join the validation range, which
    runs to the last row of the frame (inclusive).
    """
    if nclients < 1:
        raise ValueError(f'nclients must be >= 1, got {nclients}')
    ntrain = int(nrow * (1 - val_size))
    tick = ntrain // nclients
    train_bounds = [(tick * i, tick * (i + 1)) for i in range(nclients)]
    return train_bounds, (tick * nclients, nrow)


def create_fl_dataset(dataset, nclients=NCLIENTS, val_size=0.1, random_state=None):
    """Shuffle ``dataset`` and split it into per-client train sets plus one validation set.

    Args:
        dataset: DataFrame with an ``x`` column (model inputs) and a ``y`` column (labels).
        nclients: number of equal-sized client shards.
        val_size: fraction of rows reserved for validation. Any remainder of the
            per-client integer split is added to validation as well.
        random_state: optional seed forwarded to ``DataFrame.sample`` for a
            reproducible shuffle (None keeps the original unseeded behaviour).

    Returns:
        (list of ``nclients`` tf.data.Dataset, validation tf.data.Dataset), each
        built from (x, y) pairs.
    """
    shuffled = dataset.sample(frac=1, random_state=random_state)
    train_bounds, (val_start, val_stop) = _split_bounds(len(shuffled), nclients, val_size)

    def _slice_to_tf(start, stop):
        # Positional slicing: a label lookup would pull in extra rows (and leak them
        # across splits) whenever the DataFrame index contains duplicate labels.
        print(f'split from {start} to {stop}')
        part = shuffled.iloc[start:stop]
        return tf.data.Dataset.from_tensor_slices((part['x'].tolist(), part['y'].tolist()))

    train_datasets = [_slice_to_tf(start, stop) for start, stop in train_bounds]
    test_dataset = _slice_to_tf(val_start, val_stop)
    return train_datasets, test_dataset


raw_train_datasets, raw_test_dataset = create_fl_dataset(dataset)

In [None]:
# BATCH_SIZE = 32
# SHUFFLE_BUFFER = 100
# NUM_EPOCHS = 5
# for i in range(0, len(train_datasets)):
#     train_datasets[i] = train_datasets[i].repeat(NUM_EPOCHS).shuffle(SHUFFLE_BUFFER).batch(BATCH_SIZE)
# test_dataset = test_dataset.batch(BATCH_SIZE)

In [None]:
BATCH_SIZE = 32
SHUFFLE_BUFFER = 1000
NUM_EPOCHS = 1


def client_fn(client_id):
    """Build the batched training pipeline for one simulated client.

    Args:
        client_id: integer position in ``raw_train_datasets``.

    Returns:
        That client's dataset repeated NUM_EPOCHS times, shuffled with a
        SHUFFLE_BUFFER-element buffer, and batched into BATCH_SIZE batches.
    """
    return raw_train_datasets[client_id].repeat(NUM_EPOCHS).shuffle(SHUFFLE_BUFFER).batch(BATCH_SIZE)


# One batched tf.data.Dataset per client, in client order. This list is what
# iterative_process.next() consumes each round. client_fn is called directly rather
# than through a tff.simulation ClientData wrapper: the wrapper only forwarded to
# client_fn, and its constructor API has changed between TFF versions.
client_data = [client_fn(client_id) for client_id in range(len(raw_train_datasets))]

test_dataset = raw_test_dataset.batch(BATCH_SIZE)

In [None]:
# Pull the first validation batch and convert every tensor in it to a NumPy array.
sample_batch = tf.nest.map_structure(
    lambda tensor: tensor.numpy(),
    next(iter(test_dataset)),
)

In [None]:
# Create Deep and Wide CNN. IMG_SHAPE comes from the config cell; it is deliberately not
# redefined here so that preprocessing and the model input shape cannot drift apart.
def create_model(nclass, img_shape=IMG_SHAPE):
    """Build the multi-branch CNN classifier.

    Several Conv2D branches read the same single-channel input. Their flattened feature
    maps are concatenated and fed to one Dense output layer.

    Args:
        nclass: number of output classes.
        img_shape: (rows, cols) of one sample; a trailing channel axis of 1 is added.

    Returns:
        An uncompiled tf.keras.Model. Its output is unnormalised logits: the final
        Dense layer has no activation.
    """
    img_input = tf.keras.Input(shape=img_shape + (1, ))

    # Single-conv branches with kernels widening along the column axis. Layers are
    # created in the same order as the original hand-written version, which keeps
    # layer names and weight ordering (and hence checkpoints) unchanged.
    flat_branches = []
    for kernel in ((1, 1), (1, 2), (1, 4)):
        branch = tf.keras.layers.Conv2D(32, kernel, activation='relu')(img_input)
        flat_branches.append(tf.keras.layers.Flatten()(branch))

    # Stacked branch: conv -> overlapping max-pool (stride 1) -> conv. Both the pooled
    # map and the deeper conv map are fed to the classifier.
    pooled = tf.keras.layers.Conv2D(32, (2, 2), activation='relu')(img_input)
    pooled = tf.keras.layers.MaxPooling2D((2, 2), strides=(1, 1))(pooled)
    deeper = tf.keras.layers.Conv2D(32, (2, 2), activation='relu')(pooled)

    x = tf.keras.layers.concatenate(
        flat_branches + [tf.keras.layers.Flatten()(pooled), tf.keras.layers.Flatten()(deeper)])

    # Zero-initialised kernel (and Keras' default zero bias): every class starts at logit 0.
    pred = tf.keras.layers.Dense(nclass, kernel_initializer='zeros')(x)

    return tf.keras.Model(inputs=[img_input], outputs=[pred])

In [None]:
def save_ckpt(state, metrics, path='./fl_ckpt'):
    """Persist the server model as a Keras SavedModel plus the training history.

    Args:
        state: TFF server state; ``state.model`` weights are copied into a fresh
            Keras model built by create_model.
        metrics: picklable history, written to ``<path>/fl_metrics.pickle``.
        path: directory passed to ``keras_model.save``.
    """
    keras_model = create_model(len(idx2lab))
    # Labels are integer class ids and create_model emits logits, so compile with the
    # sparse, from-logits loss and metric. The loaded model then evaluates correctly.
    keras_model.compile(loss=tf.keras.losses.SparseCategoricalCrossentropy(from_logits=True),
                        metrics=[tf.keras.metrics.SparseCategoricalAccuracy()])
    tff.learning.assign_weights_to_keras_model(keras_model, state.model)
    keras_model.save(path)

    with open(os.path.join(path, 'fl_metrics.pickle'), 'wb') as f:
        pickle.dump(metrics, f)


def load_ckpt(state, path='./fl_ckpt'):
    """Restore weights and history written by save_ckpt.

    Args:
        state: TFF server state used as the template; its model weights are replaced.
        path: checkpoint directory previously written by save_ckpt.

    Returns:
        (new_state, metrics), where metrics is the unpickled history object.
    """
    keras_model = tf.keras.models.load_model(path)
    state = tff.learning.state_with_new_model_weights(
              state,
              trainable_weights=[v.numpy() for v in keras_model.trainable_weights],
              non_trainable_weights=[v.numpy() for v in keras_model.non_trainable_weights])

    # NOTE(security): unpickling can execute arbitrary code -- only load checkpoints you wrote.
    with open(os.path.join(path, 'fl_metrics.pickle'), 'rb') as f:
        metrics = pickle.load(f)

    return state, metrics

In [None]:
def model_fn():
    """Return a fresh TFF model wrapping create_model.

    We _must_ create a new Keras model here and _not_ capture it from an external
    scope: TFF calls this within different graph contexts.
    """
    keras_model = create_model(len(idx2lab.keys()))
    return tff.learning.from_keras_model(
        keras_model,
        input_spec=test_dataset.element_spec,
        # create_model's final Dense layer has no activation, i.e. it outputs logits;
        # the default from_logits=False would treat them as probabilities.
        loss=tf.keras.losses.SparseCategoricalCrossentropy(from_logits=True),
        metrics=[tf.keras.metrics.SparseCategoricalAccuracy()])

In [None]:
# Federated Averaging: clients train locally with SGD (lr=0.02). The server applies the
# aggregated client update with SGD (lr=1.0), presumably plain averaging of the updates.
iterative_process = tff.learning.build_federated_averaging_process(
    model_fn=model_fn,
    client_optimizer_fn=lambda: tf.keras.optimizers.SGD(learning_rate=0.02),
    server_optimizer_fn=lambda: tf.keras.optimizers.SGD(learning_rate=1.0),
)

In [None]:
# Display the TFF type signature of the server-state initializer (sanity check).
str(iterative_process.initialize.type_signature)

In [None]:
# Initial server state; its .model holds the global model weights used for evaluation/checkpoints.
state = iterative_process.initialize()

In [None]:
# Federated evaluation computation; invoked as evaluation(model_weights, [dataset, ...]).
evaluation = tff.learning.build_federated_evaluation(model_fn=model_fn)

In [None]:
CKPT_PATH = './ckpt/fl'
NUM_ROUNDS = 10
start_round = 0
loss = list()
accuracy = list()
val_loss = list()
val_accuracy = list()

# Always start from a freshly initialised server state so re-running this cell never
# silently continues from weights left in the kernel by an earlier run. A checkpoint,
# if present, then overwrites both the weights and the per-round history.
state = iterative_process.initialize()
if os.path.exists(CKPT_PATH):
    state, history = load_ckpt(state, CKPT_PATH)
    # Layout matches the list passed to save_ckpt at the end of each round below.
    start_round, loss, accuracy, val_loss, val_accuracy = history

stime = time.time()
for round_num in range(start_round + 1, NUM_ROUNDS + 1):
    state, metrics = iterative_process.next(state, client_data)
    val_metrics = evaluation(state.model, [test_dataset])
    loss.append(metrics['train']['loss'])
    accuracy.append(metrics['train']['sparse_categorical_accuracy'])
    val_loss.append(val_metrics['loss'])
    val_accuracy.append(val_metrics['sparse_categorical_accuracy'])
    print((f'[{int(time.time()-stime)}] round: {round_num:2d}, '
           f'metrics: {metrics["train"]}, '
           f'val_metrics: {val_metrics}'))
    # Checkpoint every round: [last completed round, loss, accuracy, val_loss, val_accuracy].
    save_ckpt(state, [round_num, loss, accuracy, val_loss, val_accuracy], CKPT_PATH)

In [None]:
# Draw learning curves chart: one point per federated round, numbered from 1 to match
# round_num in the training loop. All four histories are appended in lockstep.
acc = accuracy
val_acc = val_accuracy
rounds = range(1, len(acc) + 1)

fig2, (ax1, ax2) = plt.subplots(2, 1, figsize=(8, 8))
ax1.plot(rounds, acc, label='Training Accuracy')
ax1.plot(rounds, val_acc, label='Validation Accuracy')
ax1.legend(loc='lower right')
ax1.set_ylabel('Accuracy')
ax1.set_ylim([0, 1])
ax1.set_title('Training and Validation Accuracy')

ax2.plot(rounds, loss, label='Training Loss')
ax2.plot(rounds, val_loss, label='Validation Loss')
ax2.legend(loc='upper right')
ax2.set_ylabel('Cross Entropy')
ax2.set_ylim([0, max(ax2.get_ylim())])
ax2.set_title('Training and Validation Loss')
# Each x value is a federated round (NUM_EPOCHS local epochs per client), not an epoch.
ax2.set_xlabel('round')
plt.show()