In [None]:
!pip install --quiet --upgrade tensorflow-federated

In [2]:
import tensorflow as tf
import tensorflow_federated as tff
import numpy as np
import math
import random
from matplotlib import pyplot as plt

In [None]:
# Federated EMNIST, digits only (10 classes), split by writer.
emnist_train, emnist_test = tff.simulation.datasets.emnist.load_data(only_digits=True)

In [4]:
def preprocess(dataset):
  """Map EMNIST example dicts to flattened (pixels, label) pairs.

  Each element is expected to carry 'pixels' and 'label' entries. Pixels are
  reshaped to ``[-1, 784]`` and labels to ``[-1, 1]``, the shapes declared in
  `input_spec`. No batching or shuffling happens here.
  """
  def _to_features_and_label(example):
    flat_pixels = tf.reshape(example['pixels'], [-1, 784])
    label_column = tf.reshape(example['label'], [-1, 1])
    return flat_pixels, label_column

  return dataset.map(_to_features_and_label)

In [5]:
# Number of simulated clients the pooled data is split into, and number of
# label classes (EMNIST digits).
NUM_CLIENTS, NUM_CLASSES = 50, 10

In [6]:
# Take the first NUM_CLIENTS * 10 writers and preprocess each writer's data.
# These per-writer datasets are pooled and re-split evenly further below.
client_ids = sorted(emnist_train.client_ids[:NUM_CLIENTS * 10])
federated_train_data = list(map(
    lambda writer_id: preprocess(emnist_train.create_tf_dataset_for_client(writer_id)),
    client_ids))
# The raw training ClientData is not needed after this point.
del emnist_train

In [9]:
import functools

# Pool every selected writer's examples into one Python list so they can be
# shuffled and re-split evenly (IID) across NUM_CLIENTS simulated clients.
combined_dataset = functools.reduce(tf.data.Dataset.concatenate, federated_train_data)
combined_dataset = list(combined_dataset)
# A dedicated, seeded RNG makes the split reproducible and leaves the global
# `random` state (used later for per-round client sampling) untouched.
SHUFFLE_SEED = 0
random.Random(SHUFFLE_SEED).shuffle(combined_dataset)

In [None]:
# Total number of pooled training examples; used for the client split, delta and the noise scales.
dataset_size = len(combined_dataset)

In [21]:
# The distribution here is simple.
# All clients get the same number of samples: the dataset length divided by the number of clients.
def create_clients_datasets(dataset, num_clients=None):
  """Split `dataset` into `num_clients` contiguous, equally sized shards.

  Each shard holds ``len(dataset) // num_clients`` samples. The trailing
  ``len(dataset) % num_clients`` samples are dropped so every client gets
  exactly the same amount of data.

  Args:
    dataset: A sliceable sequence of samples (here the shuffled pooled list).
    num_clients: Number of shards; defaults to the global ``NUM_CLIENTS``.

  Returns:
    A list of ``num_clients`` slices of ``dataset``.

  Raises:
    ValueError: if ``num_clients`` is not positive.
  """
  if num_clients is None:
    num_clients = NUM_CLIENTS
  if num_clients <= 0:
    raise ValueError(f'num_clients must be positive, got {num_clients}')
  # Derive the shard size from the argument itself rather than a global, so
  # the function is correct for any dataset passed in.
  shard_size = len(dataset) // num_clients
  return [dataset[i * shard_size:(i + 1) * shard_size] for i in range(num_clients)]

# One equally sized shard of the shuffled pool per simulated client.
clients_datasets = create_clients_datasets(dataset=combined_dataset)

In [None]:
import numpy as np
from scipy.stats import gaussian_kde
import matplotlib.pyplot as plt

# Sanity check of the IID split: one kernel density estimate of the label
# distribution per client. With an IID split the curves should overlap.
fig, ax = plt.subplots()
for client_samples in clients_datasets[:NUM_CLIENTS]:
  # Each sample is a (pixels, label) pair; gather and flatten the labels.
  labels = np.concatenate([sample[1].numpy().flatten() for sample in client_samples])

  kde = gaussian_kde(labels)
  x_vals = np.linspace(labels.min(), labels.max(), num=1000)
  ax.plot(x_vals, kde(x_vals))

ax.set_xlabel('Label')
ax.set_ylabel('Density')
# Derive the count from NUM_CLIENTS so the title cannot disagree with the plot.
ax.set_title(f'Kernel Density Estimate for {NUM_CLIENTS} Clients')
plt.show()

In [7]:
# Centralized test set. Batch *before* `preprocess`, so its reshape to
# [-1, 784] yields real batches. Unbatched, every element would be a batch of
# one and Keras `evaluate` would take one step per test example every round.
EVAL_BATCH_SIZE = 1024
central_emnist_test = emnist_test.create_tf_dataset_from_all_clients().batch(EVAL_BATCH_SIZE)
central_emnist_test = preprocess(central_emnist_test)

In [13]:
# Hyperparameters
# N (total clients) is NUM_CLIENTS from the configuration cell above. It is not
# redefined here, so the number of data shards and the value used in the noise
# calibration cannot drift apart.
# C (clipping norm)
clipping_norm = 50
# ε (privacy budget)
epsilon = 50
# T (communication rounds)
comms_round = 50
# K clients for each round
USERS_PER_ROUND = 20
# δ must be smaller than 1 / (total number of samples); use half of that bound.
delta = (1 / dataset_size) * 0.5

learning_rate = 0.01
momentum = 0.9
# 0-based index of the current round; the noise helpers use current_epoch_num + 1.
current_epoch_num = 0

In [None]:
# Smallest admissible Gaussian-mechanism constant: c = sqrt(2 ln(1.25 / δ)).
_log_term = math.log(1.25 / delta)
c = math.sqrt(2.0 * _log_term)

The constant $c \geq \sqrt{2 \ln (1.25 / \delta)}$ appears in both the global (server) and local (client) noise scales; the code uses the smallest admissible value.

In [8]:
def create_keras_model():
  """Build the MLP used for clients, the server and evaluation.

  Architecture: 784 inputs -> Dense(256, relu) -> Dense(NUM_CLASSES, softmax).
  Both kernel initializers are seeded so the initial weights are reproducible.
  """
  return tf.keras.models.Sequential([
      tf.keras.layers.Input(shape=(784,)),
      tf.keras.layers.Dense(
          256, activation='relu',
          kernel_initializer=tf.keras.initializers.GlorotNormal(seed=0)),
      # GlorotUniform is the Keras default for Dense kernels; seeding it keeps
      # the output layer reproducible like the hidden layer.
      tf.keras.layers.Dense(
          NUM_CLASSES, activation='softmax',
          kernel_initializer=tf.keras.initializers.GlorotUniform(seed=1)),
  ])

# (features, labels) spec with an unspecified batch dimension; matches the
# shapes produced by `preprocess`.
input_spec = (
    tf.TensorSpec(dtype=tf.float32, shape=(None, 784)),  # flattened pixels
    tf.TensorSpec(dtype=tf.int32, shape=(None, 1)),      # integer labels
)

def model_fn():
  """Wrap a freshly built Keras MLP as a TFF model.

  Uses sparse categorical cross-entropy as the loss and sparse categorical
  accuracy as the metric, with `input_spec` describing the batches.
  """
  wrapped_keras_model = create_keras_model()
  training_loss = tf.keras.losses.SparseCategoricalCrossentropy()
  accuracy_metrics = [tf.keras.metrics.SparseCategoricalAccuracy()]
  return tff.learning.models.from_keras_model(
      wrapped_keras_model,
      input_spec=input_spec,
      loss=training_loss,
      metrics=accuracy_metrics)

In [None]:
def get_global_stddev():
  """Return the server-side (downlink) Gaussian noise std for K-of-N client sampling.

  Reads the globals NUM_CLIENTS (N), USERS_PER_ROUND (K), epsilon, c,
  clipping_norm (C), dataset_size and current_epoch_num. T is taken as
  ``current_epoch_num + 1`` and the local dataset size m as
  ``dataset_size / NUM_CLIENTS``. The exposure count L of the published
  formula is effectively 1 here: the radicand is ``T**2 / b**2 - K``.

  Returns 0 when ``T <= epsilon / gamma``.

  NOTE(review): when called while a tf.function / TFF computation is being
  traced, this runs once at trace time, so the returned value is frozen for
  the current_epoch_num at that moment — confirm that is intended.
  NOTE(review): math.sqrt raises ValueError if T**2 / b**2 < K.
  """
  n_total = NUM_CLIENTS
  k_sampled = USERS_PER_ROUND
  t_rounds = current_epoch_num + 1

  sample_ratio = k_sampled / n_total
  gamma = -1 * np.log(1 - sample_ratio + sample_ratio * np.exp((-1 * epsilon) / math.sqrt(k_sampled)))
  if not t_rounds > epsilon / gamma:
    return 0

  inverse_ratio = n_total / k_sampled
  b = -1 * (t_rounds / epsilon) * np.log(1 - inverse_ratio + inverse_ratio * np.exp((-1 * epsilon) / t_rounds))
  local_size = dataset_size / n_total
  numerator = 2 * c * clipping_norm * math.sqrt((t_rounds ** 2) / (b ** 2) - k_sampled)
  return numerator / (local_size * k_sampled * epsilon)

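The following formula gives the global (server-side) noise standard deviation computed in `get_global_stddev`. Here $C$ is the clipping norm, $m$ the size of each client's local dataset, $K$ the number of clients sampled per round, $T$ the number of aggregation rounds and $L$ the number of exposures of each client's parameters.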



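$\sigma_{\mathrm{D}}= \begin{cases}\frac{2 c C \sqrt{\frac{T^2}{b^2}-L^2 K}}{m K \epsilon} & T>\frac{\epsilon}{\gamma} \\ 0 & T \leq \frac{\epsilon}{\gamma}\end{cases}$

where $b=-\frac{T}{\epsilon}\ln\left(1-\frac{N}{K}+\frac{N}{K}e^{-\epsilon/T}\right)$ and $\gamma=-\ln\left(1-\frac{K}{N}+\frac{K}{N}e^{-\epsilon/\sqrt{K}}\right)$, with $N$ the total number of clients. The implementation uses $L = 1$, so the radicand becomes $T^2/b^2 - K$.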

In [None]:
def clip_weight(weight, clipping_norm):
  """Rescale `weight` in place so that its L2 norm is at most `clipping_norm`.

  Implements w <- w / max(1, ||w||_2 / C) for a single tensor. Tensors whose
  norm is already within C are left unchanged; larger ones are scaled down
  onto the norm ball. The whole tensor shares one scale factor, so its
  direction is preserved. The norm is taken per variable (per layer tensor),
  not over the whole model.

  Args:
    weight: A `tf.Variable`; the clipped value is written back with `assign`.
    clipping_norm: Clipping bound C (expected > 0).

  Returns:
    The result of `weight.assign(...)`.
  """
  norm = tf.norm(tf.cast(weight, tf.float32))
  # A single scalar divisor: dividing element-wise by max(1, w / C) would only
  # cap individual positive entries and never bound the norm.
  scale = tf.maximum(tf.constant(1.0, dtype=tf.float32),
                     tf.divide(norm, tf.cast(clipping_norm, tf.float32)))
  return weight.assign(tf.divide(weight, tf.cast(scale, weight.dtype)))

Each client applies the following clipping mechanism locally after finishing its local training pass. It is implemented in the `clip_weight` function.

$w_i^{(t)}=w_i^{(t)} / \max \left(1, \frac{\left\|w_i^{(t)}\right\|}{C}\right)$

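The local noise scale at each client is $\sigma_{\mathrm{U}}=c L \Delta s_{\mathrm{U}} / \epsilon$, where $\Delta s_{\mathrm{U}} = 2C/m$ is the sensitivity of the clipped local parameters and $m$ is the size of a client's local dataset. It is computed below in the `client_update` function.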

In [24]:
@tf.function
def client_update(model, dataset, server_weights, client_optimizer):
  """Train locally from the server weights, then clip and add local DP noise.

  Steps:
    1. Copy `server_weights` into the model's trainable variables.
    2. Take one optimizer step per batch in `dataset`.
    3. Clip every trainable variable with `clip_weight` (which assigns in place).
    4. Add Gaussian noise with std ``c * Δs_U * (current_epoch_num + 1) / epsilon``,
       where ``Δs_U = 2 * clipping_norm / m`` and ``m = dataset_size / NUM_CLIENTS``
       is the per-client sample count produced by `create_clients_datasets`.

  Args:
    model: TFF model exposing `trainable_variables` and `forward_pass`.
    dataset: Iterable of batches accepted by `model.forward_pass`.
    server_weights: Structure matching `model.trainable_variables`.
    client_optimizer: Keras optimizer used for the local steps.

  Returns:
    The model's trainable variables after training, clipping and noising.

  NOTE(review): c, clipping_norm, epsilon, dataset_size, NUM_CLIENTS and
  current_epoch_num are Python globals read while tracing. If this function is
  traced only once (as appears to happen when the enclosing tff.tf_computation
  is built), the noise scale is frozen at that point and later changes to
  current_epoch_num have no effect — confirm.
  """
  client_weights = model.trainable_variables
  # Initialize the client model with the current server weights.
  tf.nest.map_structure(lambda x, y: x.assign(y),
                        client_weights, server_weights)

  # Use the client_optimizer to update the local model.
  for batch in dataset:
    with tf.GradientTape() as tape:
      # Compute a forward pass on the batch of data.
      outputs = model.forward_pass(batch)
    grads = tape.gradient(outputs.loss, client_weights)
    client_optimizer.apply_gradients(zip(grads, client_weights))

  # clip_weight writes the clipped value back into each variable itself.
  for var in client_weights:
    clip_weight(var, clipping_norm)

  # Each client holds dataset_size / NUM_CLIENTS samples. len(client_ids)
  # counts the original EMNIST writers (up to NUM_CLIENTS * 10), which would
  # overstate the sensitivity and disagree with get_global_stddev's m.
  local_dataset_size = dataset_size / NUM_CLIENTS
  sensitivity = 2 * clipping_norm / local_dataset_size
  stddev = (c * sensitivity * (current_epoch_num + 1)) / epsilon
  for var in client_weights:
    noise = tf.random.normal(shape=tf.shape(var), mean=0.0,
                             stddev=tf.cast(stddev, var.dtype))
    var.assign_add(noise)
  return client_weights

@tf.function
def server_update(model, mean_client_weights):
  """Load the averaged client weights into the server model and add global DP noise.

  The noise std comes from `get_global_stddev()`. A value of 0 means the
  added noise is all zeros.

  Returns:
    The model's trainable variables after the update.

  NOTE(review): get_global_stddev() is a plain Python call, so it runs while
  this function is traced and its result is embedded as a constant — confirm
  whether a per-round value was intended.
  """
  server_vars = model.trainable_variables
  tf.nest.map_structure(lambda target, averaged: target.assign(averaged),
                        server_vars, mean_client_weights)
  global_std = get_global_stddev()
  for target in server_vars:
    target.assign_add(tf.random.normal(shape=tf.shape(target), mean=0.0,
                                       stddev=tf.cast(global_std, target.dtype)))
  return server_vars

@tff.tf_computation
def server_init():
  """Return the trainable weights of a freshly built model (initial global state)."""
  return model_fn().trainable_variables

@tff.federated_computation
def initialize_fn():
  """Place the initial model weights from `server_init` at the SERVER."""
  initial_weights = server_init()
  return tff.federated_value(initial_weights, tff.SERVER)

# Type signatures used to declare the TFF computations below.
model_weights_type = server_init.type_signature.result
whimsy_model = model_fn()  # built only to read its input_spec
tf_dataset_type = tff.SequenceType(whimsy_model.input_spec)


@tff.tf_computation(tf_dataset_type, model_weights_type)
def client_update_fn(tf_dataset, server_weights):
  """Client-side TFF computation: build a fresh model and SGD(momentum), then run `client_update`."""
  local_model = model_fn()
  local_optimizer = tf.keras.optimizers.SGD(momentum=momentum, learning_rate=learning_rate)
  return client_update(local_model, tf_dataset, server_weights, local_optimizer)

@tff.tf_computation(model_weights_type)
def server_update_fn(mean_client_weights):
  """Server-side TFF computation: apply `server_update` to a freshly built model."""
  return server_update(model_fn(), mean_client_weights)

# Placed (federated) versions of the weight and dataset types.
federated_dataset_type = tff.FederatedType(tf_dataset_type, tff.CLIENTS)
federated_server_type = tff.FederatedType(model_weights_type, tff.SERVER)

@tff.federated_computation(federated_server_type, federated_dataset_type)
def next_fn(server_weights, federated_dataset):
  """Run one round of federated training.

  The round broadcasts the weights, runs the noisy local updates, averages the
  client weights and applies the noisy server update.
  """
  broadcast_weights = tff.federated_broadcast(server_weights)
  updated_client_weights = tff.federated_map(
      client_update_fn, (federated_dataset, broadcast_weights))
  averaged_weights = tff.federated_mean(updated_client_weights)
  return tff.federated_map(server_update_fn, averaged_weights)

In [None]:
# Bundle the initialization and per-round computations into one process.
federated_algorithm = tff.templates.IterativeProcess(next_fn=next_fn, initialize_fn=initialize_fn)

In [25]:
def evaluate(server_state):
  """Evaluate global weights on the centralized EMNIST test set.

  Builds a fresh Keras model with the training architecture and loads
  `server_state` with `set_weights`. This presumably relies on the federated
  state being the trainable weights in `get_weights()` order — confirm. It then
  runs `evaluate` on `central_emnist_test`.

  Returns:
    `[loss, accuracy]` as returned by Keras `Model.evaluate`.
  """
  eval_model = create_keras_model()
  eval_model.compile(loss=tf.keras.losses.SparseCategoricalCrossentropy(),
                     metrics=['accuracy'])
  eval_model.set_weights(server_state)
  return eval_model.evaluate(central_emnist_test)

In [None]:
import tensorflow as tf
# Prefer the first visible GPU and fall back to the CPU otherwise.
device_name = tf.test.gpu_device_name()
if not device_name:
  device_name = "/device:CPU:0"
  print("No GPU, using {}.".format(device_name))
else:
  print("Found GPU at: {}".format(device_name))

with tf.device(device_name):
  server_state = federated_algorithm.initialize()

# Per-round evaluation metrics on the centralized test set.
history = {'loss': [], 'accuracy': []}

In [22]:
# Indices into clients_datasets to sample from each round (one per shard).
clients_ids = list(range(NUM_CLIENTS))

In [None]:
for round_idx in range(comms_round):
  # current_epoch_num is 0-based (initialized to 0); the noise helpers read it
  # as T = current_epoch_num + 1, so it must not be pre-incremented here.
  # NOTE(review): those helpers appear to be evaluated once, when the TFF
  # computations are traced, so updating this global may not change the noise
  # actually added — confirm.
  current_epoch_num = round_idx
  print(f'Round: {round_idx + 1}')
  selected_clients = random.sample(clients_ids, USERS_PER_ROUND)
  current_federated_train_data = [clients_datasets[client_id] for client_id in selected_clients]
  server_state = federated_algorithm.next(server_state, current_federated_train_data)
  loss, accuracy = evaluate(server_state)
  history['loss'].append(loss)
  history['accuracy'].append(accuracy)

In [None]:
import pickle

# Persist the per-round metrics; the file name encodes C, epsilon and K.
history_path = f'history_c{clipping_norm}_e{epsilon}_k{USERS_PER_ROUND}.pickle'
with open(history_path, 'wb') as file:
  pickle.dump(history, file)

In [None]:
# Centralized test loss, one point per communication round.
plt.plot(history['loss'])
plt.title('Evaluation')
plt.ylabel('loss')
plt.xlabel('communication round')
plt.legend(['loss'], loc='upper left')
plt.show()

In [None]:
# Loss (top) and accuracy (bottom) on the centralized test set, one point per
# communication round.
fig, (ax1, ax2) = plt.subplots(2, 1, figsize=(8, 6))
fig.suptitle('Model Performance', fontsize=14)

# Plot Loss
ax1.plot(history['loss'])
ax1.set_ylabel('Loss')
ax1.set_xlabel('Round')
ax1.legend(['Loss'], loc='upper right')

# Plot Accuracy
ax2.plot(history['accuracy'])
ax2.set_ylabel('Accuracy')
ax2.set_xlabel('Round')
ax2.legend(['Accuracy'], loc='lower right')

# Leave vertical room so the upper x-label does not overlap the lower plot.
plt.subplots_adjust(hspace=0.5)

plt.show()

In [None]:
# Export the final global weights as a Keras HDF5 weights file (Colab path).
weights_path = f'/content/FLDP_iid_c{clipping_norm}_e{epsilon}_k{USERS_PER_ROUND}.h5'
model = create_keras_model()
model.set_weights(server_state)
model.save_weights(weights_path)