In [2]:
import nest_asyncio
nest_asyncio.apply()

import tensorflow as tf
import tensorflow_federated as tff
from tensorflow import TensorSpec
from tensorflow.data import Dataset
from tensorflow_federated.python.simulation.datasets import ClientData

In [3]:
# Load the federated EMNIST train/test splits (default `load_data` arguments).
# Both are TFF ClientData objects: per-client datasets are obtained with
# `create_tf_dataset_for_client(client_id)` and ids come from `.client_ids`.
train, test = tff.simulation.datasets.emnist.load_data()

Metal device set to: Apple M1

systemMemory: 8.00 GB
maxCacheSize: 2.67 GB



In [30]:
import json

# Map each training client id to the most frequent label in that client's
# data. The JSON file is keyed by the client's position in `train.client_ids`
# (JSON object keys are strings, hence the `int(k)` conversion).
# Use a context manager so the file handle is closed after reading.
with open('most_frequent_labels.json') as f:
    d = json.load(f)
d2 = {train.client_ids[int(k)]: v for k, v in d.items()}

In [31]:
BATCH_SIZE = 500

def preprocess(dataset):
  """Batch raw EMNIST elements and turn each batch into an `(x, y)` tuple.

  `x` is `element['pixels']` reshaped to `[-1, 784]` (28*28 flattened) and
  `y` is `element['label']` reshaped to `[-1, 1]`. The batch size is read
  from the global `BATCH_SIZE` when the function is called.
  """
  flat_size = 28 * 28

  def to_xy(batch):
    pixels = tf.reshape(batch['pixels'], [-1, flat_size])
    labels = tf.reshape(batch['label'], [-1, 1])
    return pixels, labels

  batched = dataset.batch(BATCH_SIZE)
  return batched.map(to_xy)

In [33]:
def aux(client_id):
    """Return the examples of `client_id` whose label is that client's most frequent one.

    Reads the client's data from the global `train` ClientData, applies
    `preprocess`, and keeps only the rows whose label equals `d2[client_id]`
    (the most frequent label for that client).

    Every batch is scanned, not just the first one, so examples beyond the
    first `BATCH_SIZE` are not silently dropped.

    Args:
      client_id: a client id present in `train.client_ids` and in `d2`.

    Returns:
      A tuple `(xs, ys)`. `xs` holds the selected preprocessed images (one row
      per example) and `ys` the matching labels as a rank-1 tensor. Both are
      empty tensors if the client has no examples.

    Raises:
      KeyError: if `client_id` is not a key of `d2`.
    """
    target_label = d2[client_id]
    batches = preprocess(train.create_tf_dataset_for_client(client_id))

    xs_parts, ys_parts = [], []
    for xs, ys in batches:
        # `preprocess` reshapes labels to [-1, 1]; take column 0 so the mask
        # lines up with the rows of `xs`.
        labels = ys[:, 0]
        mask = labels == target_label
        xs_parts.append(tf.boolean_mask(xs, mask))
        ys_parts.append(tf.boolean_mask(labels, mask))

    if not xs_parts:
        # Empty client dataset: return empty tensors with matching dtypes and
        # feature width instead of raising StopIteration.
        x_spec, y_spec = batches.element_spec
        return (tf.zeros([0, x_spec.shape[-1]], dtype=x_spec.dtype),
                tf.zeros([0], dtype=y_spec.dtype))

    return tf.concat(xs_parts, axis=0), tf.concat(ys_parts, axis=0)

In [51]:
def dt_fn(client_id):
    """Build a per-example `tf.data.Dataset` of `(x, y)` pairs for `client_id`.

    The tensors come from `aux(client_id)` and are sliced along their first
    dimension.
    """
    xs, ys = aux(client_id)
    return Dataset.from_tensor_slices((xs, ys))

In [52]:
# Build a ClientData over the same client ids as `train`, where each client's
# dataset is produced by `dt_fn` (only examples with that client's most
# frequent label).
# NOTE(review): `cd` is not used by the training cells below, which read from
# `train` directly — confirm whether training on `cd` was intended.
cd = ClientData.from_clients_and_tf_fn(train.client_ids, dt_fn) 

In [56]:
def make_federated_data(client_data, client_ids):
    """Return a list with one preprocessed dataset per id in `client_ids`.

    Datasets are created with `client_data.create_tf_dataset_for_client` and
    passed through `preprocess`; output order follows `client_ids`.
    """
    datasets = []
    for cid in client_ids:
        raw = client_data.create_tf_dataset_for_client(cid)
        datasets.append(preprocess(raw))
    return datasets

In [57]:
NUM_CLIENTS = 10

# Take the first NUM_CLIENTS client ids (in sorted order) from the unfiltered
# `train` split and preprocess each client's dataset.
# NOTE(review): this uses `train`, not the label-filtered `cd`. `cd`'s datasets
# yield (x, y) tuples (see `dt_fn`), which `preprocess` cannot consume as-is
# because it indexes elements by 'pixels' and 'label' — confirm which data
# source is intended.
client_ids = sorted(train.client_ids)[:NUM_CLIENTS]
federated_train_data = make_federated_data(train, client_ids)

In [6]:
def create_keras_model():
  """Build a single Dense(10) + Softmax classifier over 784-dim inputs.

  The Dense kernel uses a GlorotNormal initializer with a fixed seed (0).
  """
  layers = [
      tf.keras.layers.Input(shape=(784,)),
      tf.keras.layers.Dense(
          10, kernel_initializer=tf.keras.initializers.GlorotNormal(seed=0)),
      tf.keras.layers.Softmax(),
  ]
  return tf.keras.models.Sequential(layers)

def model_fn():
  """Wrap a fresh Keras model from `create_keras_model` as a TFF model.

  The input spec is read from the global `federated_train_data[0]`, so that
  list must already exist when this is called.
  """
  return tff.learning.from_keras_model(
      create_keras_model(),
      input_spec=federated_train_data[0].element_spec,
      loss=tf.keras.losses.SparseCategoricalCrossentropy(),
      metrics=[tf.keras.metrics.SparseCategoricalAccuracy()],
  )

In [7]:
def next_fn(server_weights, federated_dataset):
  """Conceptual outline of one federated-averaging round (pseudo-code).

  NOTE(review): not runnable as written.
  - `broadcast` and `mean` are not defined in the visible notebook.
  - `client_update` is defined below with the signature
    `(model, dataset, server_weights, client_optimizer)`.
  - `server_update` takes `(model, mean_client_weights)`.

  This definition is shadowed by the `@tff.federated_computation` `next_fn`
  defined later, which is the one passed to `IterativeProcess`.
  """
  # Broadcast the server weights to the clients.
  server_weights_at_client = broadcast(server_weights)

  # Each client computes their updated weights.
  client_weights = client_update(federated_dataset, server_weights_at_client)

  # The server averages these updates.
  mean_client_weights = mean(client_weights)

  # The server updates its model.
  server_weights = server_update(mean_client_weights)

  return server_weights

In [8]:
@tf.function
def client_update(model, dataset, server_weights, client_optimizer):
  """Performs training (using the server model weights) on the client's dataset.

  Args:
    model: TFF model (e.g. from `model_fn`). Its `trainable_variables` are
      overwritten with `server_weights` and then updated in place.
    dataset: iterable of batches accepted by `model.forward_pass`.
    server_weights: structure matching `model.trainable_variables`
      (`tf.nest.map_structure` requires the structures to agree).
    client_optimizer: Keras optimizer used for the local gradient steps.

  Returns:
    The model's trainable variables after one pass over `dataset`.
  """
  # Handle to the client model's trainable variables (mutated below).
  client_weights = model.trainable_variables
  # Assign the server weights to the client model.
  tf.nest.map_structure(lambda x, y: x.assign(y),
                        client_weights, server_weights)

  # Use the client_optimizer to update the local model.
  for batch in dataset:
    with tf.GradientTape() as tape:
      # Compute a forward pass on the batch of data
      outputs = model.forward_pass(batch)

    # Compute the corresponding gradient
    grads = tape.gradient(outputs.loss, client_weights)
    grads_and_vars = zip(grads, client_weights)

    # Apply the gradient using a client optimizer.
    client_optimizer.apply_gradients(grads_and_vars)

  return client_weights

In [9]:
@tf.function
def server_update(model, mean_client_weights):
  """Copy the averaged client weights into `model` and return its variables.

  `mean_client_weights` must have the same nested structure as
  `model.trainable_variables`; `tf.nest.map_structure` enforces this.
  """
  server_vars = model.trainable_variables
  tf.nest.map_structure(
      lambda variable, new_value: variable.assign(new_value),
      server_vars,
      mean_client_weights,
  )
  return server_vars

In [10]:
@tff.tf_computation
def server_init():
  """Return the trainable variables of a freshly built model from `model_fn`."""
  return model_fn().trainable_variables

In [11]:
@tff.federated_computation
def initialize_fn():
  """Place the initial model weights from `server_init` at the SERVER."""
  initial_weights = server_init()
  return tff.federated_value(initial_weights, tff.SERVER)

In [12]:
# TFF type of a single client's dataset: a sequence whose elements match the
# model's input spec (which `model_fn` takes from `federated_train_data[0]`).
tf_dataset_type = tff.SequenceType(model_fn().input_spec)
tf_dataset_type

SequenceType(StructType([TensorType(tf.float32, [None, 784]), TensorType(tf.int32, [None, 1])]) as tuple)

In [32]:
# Compact string form of the client dataset type.
str(tf_dataset_type)

'<float32[?,784],int32[?,1]>*'

In [13]:
# TFF type of the model weights returned by `server_init`; used below to type
# the client/server tf_computations.
model_weights_type = server_init.type_signature.result
model_weights_type

StructType([TensorType(tf.float32, [784, 10]), TensorType(tf.float32, [10])]) as list

In [14]:
@tff.tf_computation(tf_dataset_type, model_weights_type)
def client_update_fn(tf_dataset, server_weights):
  """Run `client_update` on one client's dataset, starting from `server_weights`.

  Each call builds a fresh model via `model_fn` and a plain SGD optimizer
  with learning rate 0.01.
  """
  return client_update(
      model_fn(),
      tf_dataset,
      server_weights,
      tf.keras.optimizers.SGD(learning_rate=0.01),
  )

In [15]:
@tff.tf_computation(model_weights_type)
def server_update_fn(mean_client_weights):
  """Apply the averaged client weights to a fresh model and return its variables."""
  return server_update(model_fn(), mean_client_weights)

In [16]:
# Placed types for the federated computation: model weights live at the
# SERVER; each client holds a dataset of `tf_dataset_type` (CLIENTS placement).
federated_server_type = tff.FederatedType(model_weights_type, tff.SERVER)
federated_dataset_type = tff.FederatedType(tf_dataset_type, tff.CLIENTS)

In [17]:
@tff.federated_computation(federated_server_type, federated_dataset_type)
def next_fn(server_weights, federated_dataset):
  """Run one round of federated averaging.

  Steps:
  1. Broadcast the server weights to the clients.
  2. Train locally on each client's dataset (`client_update_fn`).
  3. Take the unweighted mean of the client weights at the server.
  4. Write that mean into the server model (`server_update_fn`).
  """
  weights_at_clients = tff.federated_broadcast(server_weights)
  locally_trained = tff.federated_map(
      client_update_fn, (federated_dataset, weights_at_clients))
  averaged = tff.federated_mean(locally_trained)
  return tff.federated_map(server_update_fn, averaged)

In [18]:
# Bundle the initialize/next computations into TFF's IterativeProcess template.
federated_algorithm = tff.templates.IterativeProcess(
    initialize_fn=initialize_fn,
    next_fn=next_fn
)

In [19]:
# Inspect the type signature of the initialize computation.
federated_algorithm.initialize.type_signature

FunctionType(None, FederatedType(StructType([TensorType(tf.float32, [784, 10]), TensorType(tf.float32, [10])]) as list, PlacementLiteral('server'), True))

In [20]:
# Compact string form of the initialize type signature.
str(federated_algorithm.initialize.type_signature)

'( -> <float32[784,10],float32[10]>@SERVER)'

In [21]:
# Inspect the type signature of the next (one training round) computation.
federated_algorithm.next.type_signature

FunctionType(StructType([('server_weights', FederatedType(StructType([TensorType(tf.float32, [784, 10]), TensorType(tf.float32, [10])]) as list, PlacementLiteral('server'), True)), ('federated_dataset', FederatedType(SequenceType(StructType([TensorType(tf.float32, [None, 784]), TensorType(tf.int32, [None, 1])]) as tuple), PlacementLiteral('clients'), False))]) as OrderedDict, FederatedType(StructType([TensorType(tf.float32, [784, 10]), TensorType(tf.float32, [10])]) as list, PlacementLiteral('server'), True))

In [22]:
# Compact string form of the next type signature.
str(federated_algorithm.next.type_signature)

'(<server_weights=<float32[784,10],float32[10]>@SERVER,federated_dataset={<float32[?,784],int32[?,1]>*}@CLIENTS> -> <float32[784,10],float32[10]>@SERVER)'

In [23]:
# Centralized test set: all EMNIST test clients merged into one dataset and
# batched/reshaped with the same `preprocess` used for training data.
central_emnist_test = preprocess(test.create_tf_dataset_from_all_clients())

In [24]:
def evaluate(server_state):
  """Evaluate model weights on the centralized EMNIST test set.

  Builds a fresh model with `create_keras_model`, loads `server_state` into
  it, and evaluates on the global `central_emnist_test` dataset.

  Args:
    server_state: list of weight arrays in `keras_model.get_weights()` order,
      e.g. the value returned by `federated_algorithm.initialize()` or
      `federated_algorithm.next(...)`.

  Returns:
    The result of `keras_model.evaluate`, i.e. `[loss, accuracy]`. Returning
    it (instead of discarding it) lets callers compare rounds programmatically.
  """
  keras_model = create_keras_model()
  # No optimizer is needed: the model is only evaluated, never fit.
  keras_model.compile(
      loss=tf.keras.losses.SparseCategoricalCrossentropy(),
      metrics=[tf.keras.metrics.SparseCategoricalAccuracy()],
  )
  keras_model.set_weights(server_state)
  return keras_model.evaluate(central_emnist_test)

In [61]:
# Baseline: evaluate the freshly initialized (untrained) server weights.
server_state = federated_algorithm.initialize()
evaluate(server_state)



In [60]:
NUM_ROUNDS = 1

# Run federated training rounds. The loop variable is `round_num` rather than
# `round` so the built-in `round()` is not shadowed at notebook (module) scope.
for round_num in range(NUM_ROUNDS):
  server_state = federated_algorithm.next(server_state, federated_train_data)

In [27]:
# Evaluate the server weights after the training rounds above.
evaluate(server_state)

