In [None]:
!pip install --quiet --upgrade tensorflow-federated

In [None]:
import tensorflow as tf
import tensorflow_federated as tff
import numpy as np
import math
import random
from matplotlib import pyplot as plt

In [None]:
# Load the federated EMNIST train and test datasets shipped with TFF's simulation datasets.
emnist_train, emnist_test = tff.simulation.datasets.emnist.load_data()

In [None]:
def preprocess(dataset):
  """Map every element of `dataset` to a `(pixels, label)` tuple.

  Args:
    dataset: object exposing `.map` whose elements can be indexed with the
      keys `'pixels'` and `'label'`.

  Returns:
    The mapped dataset; in each element the pixels are reshaped to
    `[-1, 784]` and the label to `[-1, 1]`.
  """
  def to_features_and_label(example):
    flat_pixels = tf.reshape(example['pixels'], [-1, 784])
    column_label = tf.reshape(example['label'], [-1, 1])
    return (flat_pixels, column_label)

  return dataset.map(to_features_and_label)

In [None]:
# Number of simulated clients the pooled samples are distributed across.
NUM_CLIENTS = 50
# Number of label classes (labels 0 to 9).
NUM_CLASSES = 10

In [None]:
def import_tff_mnist(n_samples, emnist_train):
  """Pool the preprocessed samples of the first `n_samples` clients into one list.

  Args:
    n_samples: number of client ids (not individual samples) taken from the
      front of `emnist_train.client_ids`.
    emnist_train: federated dataset exposing `client_ids` and
      `create_tf_dataset_for_client`.

  Returns:
    A flat list containing every preprocessed element of the selected clients,
    in client order.

  Note:
    The caller's reference to `emnist_train` is not released here; deleting the
    parameter inside the function would only unbind the local name. Delete the
    dataset in the calling scope if its memory has to be freed.
  """
  client_datasets = []
  for client_id in emnist_train.client_ids[:n_samples]:
    client_dataset = emnist_train.create_tf_dataset_for_client(client_id)
    # Iterating the mapped dataset materialises its elements into the list.
    client_datasets.extend(preprocess(client_dataset))
  return client_datasets

# Pool the samples of the first NUM_CLIENTS * 10 EMNIST clients; they are re-partitioned across NUM_CLIENTS clients later.
client_datasets = import_tff_mnist(NUM_CLIENTS * 10, emnist_train)

In [None]:
# Number of label classes randomly withheld from each client during partitioning
# (the last client is exempt: it receives the remainder of every label).
non_iid_strength_factor = 3

In [None]:
import random

# Shuffle first so that the order of samples inside each label bucket is random.
random.shuffle(client_datasets)

# Sort each (sample, label) pair by label in ascending order (0 to 9).
# The sort is stable, so the shuffled order within one label is preserved.
client_datasets = sorted(client_datasets, key=lambda x: int(x[1]))

# Bucket the samples by label: keys are the labels as strings, each value is the
# list of all samples carrying that label.
sorted_by_label = {str(k): [] for k in range(NUM_CLASSES)}
for sample in client_datasets:
  sorted_by_label[str(int(sample[1]))].append(sample)

# Per-client sample lists, filled by the distribution loop below.
clients_datasets_dict = {k: [] for k in range(NUM_CLIENTS)}

# For every label (integer key) draw NUM_CLIENTS - 1 distinct random cut points
# inside that label's bucket, in ascending order. Client i normally receives the
# slice between the previous cut point and its own one; the last client receives
# everything after the final cut point.
# `random.sample` raises ValueError if a bucket holds fewer than NUM_CLIENTS - 1 samples.
lables_indices = {
    int(label): sorted(random.sample(range(len(label_list)), NUM_CLIENTS - 1))
    for label, label_list in sorted_by_label.items()
}

# Next not-yet-assigned index of every label bucket.
index_tracker = {label: 0 for label in sorted_by_label}

# Distribute the samples client by client.
for i in range(NUM_CLIENTS):
  # Labels the i-th client is deprived of.
  client_execluded_classes = random.sample(range(NUM_CLASSES), non_iid_strength_factor)
  for label, label_list in sorted_by_label.items():
    if i != (NUM_CLIENTS - 1):
      # Always consume this client's cut point, even when the label is excluded.
      label_index = lables_indices[int(label)].pop(0)
      # Excluded label: skip without advancing index_tracker, so this chunk is
      # carried over to the next client that is allowed to hold the label.
      if int(label) in client_execluded_classes:
        continue
    else:
      # The last client takes the remainder of every label, so no sample is dropped.
      label_index = len(label_list)
    # Assign the next chunk to the i-th client.
    clients_datasets_dict[i].extend(label_list[index_tracker[label]:label_index])
    # Remember where the next client's chunk of this label starts.
    index_tracker[label] = label_index
  # Shuffle once per client after all labels are assigned; reshuffling after
  # every label only repeated work without changing the resulting distribution.
  random.shuffle(clients_datasets_dict[i])

In [None]:
import numpy as np
from scipy.stats import gaussian_kde
import matplotlib.pyplot as plt

# Plot a kernel density estimate of the label distribution of five random clients.
for i in random.sample(range(NUM_CLIENTS), k=5):
  # Convert every sample's label tensor to a flattened NumPy array.
  flat_arrays = [tensor[1].numpy().flatten() for tensor in clients_datasets_dict[i]]

  # Concatenate the flattened arrays into a single array
  concatenated_array = np.concatenate(flat_arrays)

  # Estimate the kernel density
  kde = gaussian_kde(concatenated_array)

  # Generate a range of values for the x-axis
  x_vals = np.linspace(concatenated_array.min(), concatenated_array.max(), num=1000)

  # Evaluate the kernel density at each x value
  y_vals = kde(x_vals)

  # Plot the kernel density estimate
  plt.plot(x_vals, y_vals, label=f'Client: {i}')

plt.xlabel('Label')
plt.ylabel('Density')
# Derive the number of excluded labels from the configuration so the title
# stays correct when non_iid_strength_factor is changed.
plt.title(f'KDE for Random Clients Data - {non_iid_strength_factor} Labels Exclusion')
plt.legend()
plt.show()

In [None]:
# Total number of samples held by all clients together.
dataset_size = sum(len(clients_datasets_dict[client_idx]) for client_idx in range(NUM_CLIENTS))

In [None]:
# Merge the test data of all clients into one centralized dataset and apply the
# same reshaping that is used for the training samples.
central_emnist_test = emnist_test.create_tf_dataset_from_all_clients()
central_emnist_test = preprocess(central_emnist_test)

In [None]:
# Hyperparameters
# N (total number of clients). NOTE(review): repeats the value assigned earlier in
# the notebook; the client partition was built with that earlier value.
NUM_CLIENTS = 50
# C (clipping norm)
clipping_norm = 50
# ε (privacy budget)
epsilon = 50
# T (communication rounds)
comms_round = 50
# K (clients sampled in each round)
USERS_PER_ROUND = 20
# δ must be smaller than 1 / (total number of samples); half of that bound is used here.
delta = (1 / dataset_size) * 0.5

# Optimizer hyperparameters (learning rate and momentum).
learning_rate = 0.01
momentum = 0.9
# Round counter read by the noise-scale computations; reassigned by the training loop.
current_epoch_num = 0

In [None]:
# Noise-scale constant c = sqrt(2 * ln(1.25 / δ)), used by both the local and the global noise scale.
c = math.sqrt(2.0 * math.log(1.25 / delta))

The constant $c \geq \sqrt{2 \ln (1.25 / \delta)}$ contributes to both the global and the local noise scales; the cell above sets it to this lower bound.

In [None]:
def create_keras_model():
  """Build the classifier: 784 inputs -> Dense(256, relu) -> Dense(10, softmax).

  The hidden layer's kernel uses a Glorot-normal initializer with seed 0; the
  output layer keeps the Keras default initializer.

  Returns:
    An uncompiled `tf.keras.models.Sequential` model.
  """
  hidden_kernel_init = tf.keras.initializers.GlorotNormal(seed=0)
  layer_stack = [
      tf.keras.layers.Input(shape=(784,)),
      tf.keras.layers.Dense(256, activation='relu', kernel_initializer=hidden_kernel_init),
      tf.keras.layers.Dense(10, activation='softmax'),
  ]
  return tf.keras.models.Sequential(layer_stack)

# Element spec handed to TFF: float32 features of shape (None, 784) paired with
# int32 labels of shape (None, 1).
input_spec = (tf.TensorSpec(shape=(None, 784), dtype=tf.float32),
              tf.TensorSpec(shape=(None, 1), dtype=tf.int32))

def model_fn():
  """Wrap a freshly built Keras model as a TFF learning model.

  Returns:
    The result of `tff.learning.models.from_keras_model` for a new
    `create_keras_model()` instance, using `input_spec`, sparse categorical
    cross-entropy as loss and sparse categorical accuracy as metric.
  """
  wrapper_kwargs = dict(
      input_spec=input_spec,
      loss=tf.keras.losses.SparseCategoricalCrossentropy(),
      metrics=[tf.keras.metrics.SparseCategoricalAccuracy()],
  )
  return tff.learning.models.from_keras_model(create_keras_model(), **wrapper_kwargs)

In [None]:
def get_global_stddev():
  """Return the standard deviation of the noise added to the aggregated model.

  Reads the module-level values `USERS_PER_ROUND` (K), `NUM_CLIENTS` (N),
  `epsilon`, `c`, `clipping_norm` (C), `dataset_size` and `current_epoch_num`.
  With T = current_epoch_num + 1 and m = dataset_size / NUM_CLIENTS, the
  result is 2 c C sqrt(T^2 / b^2 - K) / (m K epsilon) when T > epsilon / gamma,
  and 0 otherwise.

  NOTE(review): these are plain Python globals; when this function is called
  from inside a `tf.function`, their values are presumably fixed at trace
  time. Confirm that updates to `current_epoch_num` actually take effect.

  Returns:
    The noise standard deviation (0 when no server-side noise is required).
  """
  rounds_so_far = current_epoch_num + 1
  sampling_ratio = USERS_PER_ROUND / NUM_CLIENTS
  gamma = -1 * np.log(1 - sampling_ratio + sampling_ratio * np.exp((-1 * epsilon) / math.sqrt(USERS_PER_ROUND)))

  # Guard clause: below the threshold no additional server noise is added.
  if not rounds_so_far > (epsilon / gamma):
    return 0

  inverse_ratio = NUM_CLIENTS / USERS_PER_ROUND
  b = -1 * (rounds_so_far / epsilon) * np.log(1 - inverse_ratio + (inverse_ratio * np.exp((-1 * epsilon) / rounds_so_far)))
  numerator = 2 * c * clipping_norm * math.sqrt(((rounds_so_far ** 2) / (b ** 2)) - USERS_PER_ROUND)
  denominator = (dataset_size / NUM_CLIENTS) * USERS_PER_ROUND * epsilon
  return numerator / denominator

The following formula corresponds to the global noise standard deviation computed in `get_global_stddev`.



$\sigma_{\mathrm{D}}= \begin{cases}\frac{2 c C \sqrt{\frac{T^2}{b^2}-L^2 K}}{m K \epsilon} & T>\frac{\epsilon}{\gamma} \\ 0 & T \leq \frac{\epsilon}{\gamma}\end{cases}$

In [None]:
def clip_weights(weight, clipping_norm):
  """Clip a weight variable in place so that its L2 norm is at most `clipping_norm`.

  Implements w <- w / max(1, ||w|| / C), where ||w|| is the L2 norm of this
  tensor. Dividing the tensor elementwise by C instead of using its norm would
  only cap individual positive entries at C, which is not norm clipping.
  Clipping is applied per variable tensor, since the function receives one
  variable at a time.

  Args:
    weight: a `tf.Variable` (it must support `.assign`).
    clipping_norm: the clipping threshold C (a positive number).

  Returns:
    The result of `weight.assign(...)` with the clipped values.
  """
  values = tf.cast(weight, tf.float32)  # Do the arithmetic in float32.
  threshold = tf.constant(clipping_norm, dtype=tf.float32)
  # Scale factor max(1, ||w|| / C): equals 1 when the norm is already within the bound.
  scale = tf.maximum(tf.constant(1.0, dtype=tf.float32), tf.divide(tf.norm(values), threshold))
  return weight.assign(tf.cast(values / scale, weight.dtype))

The following clipping mechanism is applied locally at each client after it completes each training epoch. The calculation is performed in the `clip_weights` function.

$w_i^{(t)}=w_i^{(t)} / \max \left(1, \frac{\left\|w_i^{(t)}\right\|}{C}\right)$

Each local noise scale is calculated as $\sigma_{\mathrm{U}}=c L \Delta s_{\mathrm{U}} / \epsilon$. The calculation is performed below in the `client_update` function.

In [None]:
@tf.function
def client_update(model, dataset, server_weights, client_optimizer):
  """Trains on one client's dataset, then clips and perturbs the resulting weights.

  Args:
    model: model exposing `trainable_variables` and `forward_pass`.
    dataset: iterable of batches passed to `model.forward_pass`.
    server_weights: structure matching `model.trainable_variables`, used as
      the starting point of local training.
    client_optimizer: optimizer exposing `apply_gradients`.

  Returns:
    The client weights after one pass over `dataset`, clipping via
    `clip_weights` and the addition of Gaussian noise.

  NOTE(review): `clipping_norm`, `dataset_size`, `NUM_CLIENTS`, `c`, `epsilon`
  and `current_epoch_num` are Python globals; under `tf.function` their values
  are presumably captured when the function is traced. Confirm that changing
  `current_epoch_num` between rounds actually changes the noise scale.
  """
  # Initialize the client model with the current server weights.

  client_weights = model.trainable_variables
  # Assign the server weights to the client model.
  tf.nest.map_structure(lambda x, y: x.assign(y),
                        client_weights, server_weights)

  # Use the client_optimizer to update the local model (one pass over the dataset).
  for batch in dataset:
    with tf.GradientTape() as tape:
      # Compute a forward pass on the batch of data
      outputs = model.forward_pass(batch)

    # Compute the corresponding gradient
    grads = tape.gradient(outputs.loss, client_weights)
    grads_and_vars = zip(grads, client_weights)

    # Apply the gradient using a client optimizer.
    client_optimizer.apply_gradients(grads_and_vars)

  # Clip every variable and write the clipped value back.
  # NOTE(review): `clip_weights` already assigns to the variable, so the outer `assign` looks redundant.
  client_weights = tf.nest.map_structure(lambda x: x.assign(clip_weights(x, clipping_norm)), client_weights)
  # Sensitivity 2C / m, with m taken as the average number of samples per client.
  sensitivity = 2 * (clipping_norm) / (dataset_size / NUM_CLIENTS)
  # Local noise scale: c * (current_epoch_num + 1) * sensitivity / epsilon.
  stddev = (c * sensitivity * (current_epoch_num + 1)) / epsilon
  # Add zero-mean Gaussian noise with that scale to every variable.
  for var in client_weights:
    stddev = tf.cast(stddev, var.dtype)
    noise = tf.random.normal(shape=tf.shape(var), mean=0.0, stddev=stddev)
    var.assign_add(noise)
  return client_weights

@tf.function
def server_update(model, mean_client_weights):
  """Updates the server model weights as the average of the client model weights.

  After assigning the averaged weights, zero-mean Gaussian noise with the
  standard deviation returned by `get_global_stddev()` is added to every
  trainable variable.

  Args:
    model: model exposing `trainable_variables`.
    mean_client_weights: structure matching `model.trainable_variables`.

  Returns:
    The model's trainable variables after the assignment and noise addition.

  NOTE(review): `get_global_stddev()` is ordinary Python evaluated inside a
  `tf.function`, so its result is presumably computed once at trace time.
  """
  model_weights = model.trainable_variables
  # Assign the mean client weights to the server model.
  tf.nest.map_structure(lambda x, y: x.assign(y),
                        model_weights, mean_client_weights)
  # Global (server-side) noise scale; 0 means no noise is effectively added.
  stddev = get_global_stddev()
  for var in model_weights:
    stddev = tf.cast(stddev, var.dtype)
    noise = tf.random.normal(shape=tf.shape(var), mean=0.0, stddev=stddev)
    var.assign_add(noise)
  return model_weights

@tff.tf_computation
def server_init():
  """Return the trainable variables of a freshly constructed model."""
  return model_fn().trainable_variables

@tff.federated_computation
def initialize_fn():
  """Place the initial model weights produced by `server_init` at the server."""
  initial_weights = server_init()
  return tff.federated_value(initial_weights, tff.SERVER)

# Throwaway model instance, built only to read its input spec.
whimsy_model = model_fn()
# TFF type of one client's dataset: a sequence of elements matching the model's input spec.
tf_dataset_type = tff.SequenceType(whimsy_model.input_spec)
# TFF type of the model weights, taken from the result type of `server_init`.
model_weights_type = server_init.type_signature.result

@tff.tf_computation(tf_dataset_type, model_weights_type)
def client_update_fn(tf_dataset, server_weights):
  """TFF wrapper that runs `client_update` with a fresh model and SGD optimizer.

  Args:
    tf_dataset: one client's dataset (type `tf_dataset_type`).
    server_weights: current server weights (type `model_weights_type`).

  Returns:
    The weights produced by `client_update` for this client.
  """
  model = model_fn()
  # Take the optimizer settings from the `learning_rate` / `momentum`
  # hyperparameters instead of repeating the literals here, so that editing the
  # hyperparameters actually changes training.
  client_optimizer = tf.keras.optimizers.SGD(learning_rate=learning_rate, momentum=momentum)
  return client_update(model, tf_dataset, server_weights, client_optimizer)

@tff.tf_computation(model_weights_type)
def server_update_fn(mean_client_weights):
  """TFF wrapper that applies `server_update` to a freshly built model."""
  return server_update(model_fn(), mean_client_weights)

# Federated type of the model weights placed at the server.
federated_server_type = tff.FederatedType(model_weights_type, tff.SERVER)
# Federated type of the per-client datasets placed at the clients.
federated_dataset_type = tff.FederatedType(tf_dataset_type, tff.CLIENTS)

@tff.federated_computation(federated_server_type, federated_dataset_type)
def next_fn(server_weights, federated_dataset):
  """Run one federated round and return the updated server weights.

  Args:
    server_weights: model weights placed at the server.
    federated_dataset: the participating clients' datasets.

  Returns:
    The server weights after broadcast, local client updates, averaging and
    the server update.
  """
  # Server -> clients: send the current weights.
  broadcast_weights = tff.federated_broadcast(server_weights)

  # Clients: compute locally updated weights.
  per_client_weights = tff.federated_map(
      client_update_fn, (federated_dataset, broadcast_weights))

  # Clients -> server: average the updates.
  averaged_weights = tff.federated_mean(per_client_weights)

  # Server: apply the averaged update to its model.
  return tff.federated_map(server_update_fn, averaged_weights)

In [None]:
# Bundle initialization and the per-round step into one iterative process
# exposing `.initialize()` and `.next(...)`.
federated_algorithm = tff.templates.IterativeProcess(
    initialize_fn=initialize_fn,
    next_fn=next_fn
)

In [None]:
def evaluate(server_state):
  """Evaluate the given server weights on the centralized test set.

  Args:
    server_state: weights accepted by `set_weights` of the Keras model built
      by `create_keras_model`.

  Returns:
    The value returned by Keras `evaluate` on `central_emnist_test` for a
    model compiled with sparse categorical cross-entropy and the `'accuracy'`
    metric.
  """
  eval_model = create_keras_model()
  eval_model.compile(
      loss=tf.keras.losses.SparseCategoricalCrossentropy(),
      metrics=['accuracy'],
  )
  eval_model.set_weights(server_state)
  return eval_model.evaluate(central_emnist_test)

In [None]:
import tensorflow as tf
# Pick the GPU when TensorFlow reports one, otherwise fall back to the CPU.
device_name = tf.test.gpu_device_name()
if device_name:
    print("Found GPU at: {}".format(device_name))
else:
    device_name = "/device:CPU:0"
    print("No GPU, using {}.".format(device_name))

# Create the initial server state on the selected device.
with tf.device(device_name):
  server_state = federated_algorithm.initialize()

# Per-round evaluation results, appended to by the training loop.
history = dict(loss=[], accuracy=[])

In [None]:
# Ids of all clients that can be sampled; derived from NUM_CLIENTS so that it
# always matches the keys of the per-client dataset dictionary.
clients_ids = list(range(NUM_CLIENTS))

In [None]:
# Run `comms_round` communication rounds. The loop variable is deliberately not
# called `round`, which would shadow the builtin for the rest of the session.
for round_idx in range(comms_round):
  # NOTE(review): the federated computations were already built from Python
  # globals; confirm that this reassignment reaches the noise-scale formulas.
  current_epoch_num = round_idx + 1
  print(f'Round: {current_epoch_num}')
  # Sample K distinct clients for this round and collect their datasets.
  selected_clients = random.sample(clients_ids, USERS_PER_ROUND)
  current_federated_train_data = [clients_datasets_dict[client_id] for client_id in selected_clients]
  server_state = federated_algorithm.next(server_state, current_federated_train_data)
  # Track the centralized test loss and accuracy after every round.
  loss, accuracy = evaluate(server_state)
  history['loss'].append(loss)
  history['accuracy'].append(accuracy)

In [None]:
import pickle
# Persist the per-round loss/accuracy history; the file name encodes the clipping
# norm, epsilon, clients per round and the non-IID strength factor.
with open(f'history_c{clipping_norm}_e{epsilon}_k{USERS_PER_ROUND}_sf{non_iid_strength_factor}.pickle', 'wb') as file:
  pickle.dump(history, file)

In [None]:
# Evaluation loss recorded after each communication round.
plt.plot(history['loss'], label='loss')
plt.title('Evaluation')
plt.xlabel('epochs')
plt.ylabel('loss')
plt.legend(loc='upper left')
plt.show()

In [None]:
import matplotlib.pyplot as plt

# Two stacked panels: loss on top, accuracy below.
fig, (ax1, ax2) = plt.subplots(2, 1, figsize=(8, 6))
fig.suptitle('Model Performance', fontsize=14)

# (axis, history key, axis/legend label, legend position) for each panel.
panels = (
    (ax1, 'loss', 'Loss', 'upper right'),
    (ax2, 'accuracy', 'Accuracy', 'lower right'),
)
for axis, metric_key, metric_label, legend_loc in panels:
  axis.plot(history[metric_key])
  axis.set_ylabel(metric_label)
  axis.set_xlabel('Epoch')
  axis.legend([metric_label], loc=legend_loc)

# Leave vertical room between the two panels, then render.
plt.subplots_adjust(hspace=0.5)
plt.show()

In [None]:
# Rebuild the Keras model, load the final server weights into it and save them to an `.h5` file.
# NOTE(review): the output directory `/content/` is hard-coded (presumably a Colab path); confirm it exists before running elsewhere.
model = create_keras_model()
model.set_weights(server_state)
model.save_weights(f'/content/FLDP_non_iid_c{clipping_norm}_e{epsilon}_k{USERS_PER_ROUND}_sf{non_iid_strength_factor}.h5')