In [2]:
import nest_asyncio

nest_asyncio.apply()

import collections
import json
from collections import OrderedDict
import numpy as np
import tensorflow as tf
import tensorflow_federated as tff
from tensorflow import TensorSpec
from tensorflow.data import Dataset
from tensorflow_federated.python.simulation.datasets import ClientData

np.random.seed(0)

tff.federated_computation(lambda: 'Hello, World!')()

b'Hello, World!'

**We start by downloading the MNIST dataset provided by TensorFlow Federated. This data includes the client_ids (i.e. the ID of the client that generated the handwritten digit), allowing for a simulation of the "only positive labels" setting.**

In [3]:
train, test = tff.simulation.datasets.emnist.load_data()

In [4]:
# total number of clients 
len(train.client_ids)

3383

In [5]:
train.element_type_structure

OrderedDict([('label', TensorSpec(shape=(), dtype=tf.int32, name=None)),
             ('pixels',
              TensorSpec(shape=(28, 28), dtype=tf.float32, name=None))])

**We then convert the MNIST dataset to a federated dataset in which each client only has samples from a single class. This is in order to simulate the "only positive labels" setting in which each client only has access to positives from one class. We choose the most frequent label per client dataset (taken from *most_frequent_labels.json*) to be the positive label for the client. We keep in mind that the overall problem setting is still that of multi-class classification.**

In [6]:
f = open('most_frequent_labels.json')
d = json.load(f)
d2 = {}
for k, v in d.items(): 
    d2[train.client_ids[int(k)]] = v

In [7]:
BATCH_SIZE = 500

def preprocess(dataset):
    def helper(element):
        x=tf.reshape(element['pixels'], [-1, 784])
        y=tf.reshape(element['label'], [-1, 1])
        return collections.OrderedDict(x=x,y=y)

    return dataset.batch(BATCH_SIZE).map(helper)

In [8]:
def aux(client_id):
    dataset = train.create_tf_dataset_for_client(client_id)
    out = preprocess(dataset)
    out = next(iter(out))
    xs = out['x']
    ys = out['y']
    
    y_ind = tf.where(ys == d2[client_id]) 
    # d2[client_id] corresponds to the most frequent label in that client dataset
    gathered_ys = tf.gather_nd(ys, y_ind)
    
    x_ind = y_ind[:, 0]
    gathered_xs = tf.gather(xs, x_ind)
    
    return OrderedDict([('x', gathered_xs), ('y', gathered_ys)])

In [9]:
def dt_fn(client_id): 
    store = aux(client_id)
    return Dataset.from_tensor_slices({'x': store['x'], 'y': store['y']})

In [10]:
cd = ClientData.from_clients_and_tf_fn(train.client_ids, dt_fn) 

In [24]:
NUM_CLIENTS = len(cd.client_ids)
NUM_CLIENTS

3383

In [14]:
def preprocess2(dataset):
    def batch_format_fn(element):
        """Flatten a batch `pixels` and return the features as an `OrderedDict`."""
        return collections.OrderedDict(
            x=tf.reshape(element['x'], [-1, 784]),
            y=tf.reshape(element['y'], [-1, 1]))

    return dataset.batch(BATCH_SIZE).map(batch_format_fn)

**We now make additional arrangements to simulate a federated training environment. In a typical federated training scenario, we deal with a potentially very large population of user devices. Of these devices, only a fraction may be available for training (mobile phone connected to internet, charging, etc.) at a given point in time. To simulate this volatility, we sample a random subset of the clients to be involved in each round of training.**

In [15]:
def make_federated_data(client_data, client_ids):
    return [preprocess2(client_data.create_tf_dataset_for_client(x))
      for x in client_ids]

In [17]:
def create_keras_model():
    return tf.keras.models.Sequential([
      tf.keras.layers.InputLayer(input_shape=(784,)),
      tf.keras.layers.Dense(10, kernel_initializer='zeros'),
      tf.keras.layers.Softmax(),
    ])

In [18]:
def model_fn():
  # We _must_ create a new model here, and _not_ capture it from an external
  # scope. TFF will call this within different graph contexts.
    keras_model = create_keras_model()
    return tff.learning.from_keras_model(
      keras_model,
      input_spec=OrderedDict([('x', TensorSpec(shape=(None, 28*28), 
                                               dtype=tf.float32, name=None)),
                             ('y', TensorSpec(shape=(None, 1), dtype=tf.int32, name=None))]),
                              
      loss=tf.keras.losses.SparseCategoricalCrossentropy(),
      metrics=[tf.keras.metrics.SparseCategoricalAccuracy()])

In [19]:
training_process = tff.learning.algorithms.build_weighted_fed_avg(
    model_fn,
    client_optimizer_fn=lambda: tf.keras.optimizers.SGD(learning_rate=0.02),
    server_optimizer_fn=lambda: tf.keras.optimizers.SGD(learning_rate=1.0))

In [20]:
print(training_process.initialize.type_signature.formatted_representation())

( -> <
  global_model_weights=<
    trainable=<
      float32[784,10],
      float32[10]
    >,
    non_trainable=<>
  >,
  distributor=<>,
  client_work=<>,
  aggregator=<
    value_sum_process=<>,
    weight_sum_process=<>
  >,
  finalizer=<
    int64
  >
>@SERVER)


In [21]:
train_state = training_process.initialize()

**The volatility simulation happens here: at each round of training, we choose a random integer N between 100 and 300 as the number of participants in the training round. We then construct the federated train data for that round by choosing N client_ids.**

In [31]:
chosen_client_ids

array([1435, 2830, 3113, 3130, 2753, 3108, 2826,   86, 3115,  872, 2059,
       2818,  307, 1104, 2080, 1206, 1152, 2854,  275, 1198, 1578, 1395,
       1208, 2492, 2536, 2334, 1304,  637,  770,   94, 3298, 1899, 3341,
       1904, 2344, 1352, 3091,  607, 1434, 2498, 1272,  180, 2371, 3052,
       3332,  749, 2187, 1020, 2646, 2937, 1645,  843, 2744, 1552, 3224,
        925, 2197, 1134,   25, 1488,  956, 1913, 2934, 1141, 1469, 1619,
       2721, 1896,  928, 3300, 1531, 2811, 2169, 1350,  469, 2335,  525,
       1863, 1720, 1176,  591, 2089, 2322,  207, 2827,  166, 2159, 1153,
       2783, 1910, 2860,  216,   24,   67, 3026, 1007, 2282, 1740, 2022,
        291, 1750, 1022, 2749, 2775, 2603, 2080, 3083, 1640, 2260, 1162,
       1718, 1445,  623,  770, 1563, 1241, 2711,  821,  307, 1198, 1172,
       2997, 1565, 2595,  807, 2121,  297, 2967,  730, 1970, 2924, 1795,
       1823, 2825,  429,  199, 1447, 1085, 2389,  865, 3116, 2850, 1954,
       1112, 2081, 2693, 2280])

In [29]:
cd.client_ids[chosen_client_ids]

TypeError: only integer scalar arrays can be converted to a scalar index

In [None]:
N = np.random.randint(100, 300)
chosen_client_ids = np.random.randint(0, NUM_CLIENTS, N)
federated_train_data = make_federated_data(cd, cd.client_ids)

In [142]:
result = training_process.next(train_state, federated_train_data)

In [145]:
train_state = result.state
train_metrics = result.metrics
print('round  1, metrics={}'.format(train_metrics))

round  1, metrics=OrderedDict([('distributor', ()), ('client_work', OrderedDict([('train', OrderedDict([('sparse_categorical_accuracy', 0.19669421), ('loss', 2.3025854), ('num_examples', 1210), ('num_batches', 100)]))])), ('aggregator', OrderedDict([('mean_value', ()), ('mean_weight', ())])), ('finalizer', ())])


In [146]:
NUM_ROUNDS = 11
for round_num in range(2, NUM_ROUNDS):
  result = training_process.next(train_state, federated_train_data)
  train_state = result.state
  train_metrics = result.metrics
  print('round {:2d}, metrics={}'.format(round_num, train_metrics))

round  2, metrics=OrderedDict([('distributor', ()), ('client_work', OrderedDict([('train', OrderedDict([('sparse_categorical_accuracy', 0.42727274), ('loss', 2.7164285), ('num_examples', 1210), ('num_batches', 100)]))])), ('aggregator', OrderedDict([('mean_value', ()), ('mean_weight', ())])), ('finalizer', ())])
round  3, metrics=OrderedDict([('distributor', ()), ('client_work', OrderedDict([('train', OrderedDict([('sparse_categorical_accuracy', 0.19669421), ('loss', 4.237716), ('num_examples', 1210), ('num_batches', 100)]))])), ('aggregator', OrderedDict([('mean_value', ()), ('mean_weight', ())])), ('finalizer', ())])
round  4, metrics=OrderedDict([('distributor', ()), ('client_work', OrderedDict([('train', OrderedDict([('sparse_categorical_accuracy', 0.42727274), ('loss', 3.0594237), ('num_examples', 1210), ('num_batches', 100)]))])), ('aggregator', OrderedDict([('mean_value', ()), ('mean_weight', ())])), ('finalizer', ())])
round  5, metrics=OrderedDict([('distributor', ()), ('clien