# Defense Idea implementation

In [1]:
import os
os.environ['TF_CPP_MIN_LOG_LEVEL'] = '3'

import numpy as np
import tensorflow as tf
import tensorflow_federated as tff
from matplotlib import pyplot as plt

import collections
from tqdm import tqdm
import random

In [2]:
# ---------------------------------------------------------------------------
# Experiment configuration (module-level settings read by the cells below).
# ---------------------------------------------------------------------------

# Number of loss lines next_fn reads back from the losses file each round.
n_clients = 100
# Only referenced by the commented-out evaluation code at the end of the file.
n_test_clients = 300

# How many times preprocess() repeats a train / test client dataset.
n_train_dataset_epochs = 6
n_test_dataset_epochs = 3
# Batch sizes used by preprocess() for train / test datasets.
batch_size_train = 20
batch_size_test = 20
# Buffer sizes passed to Dataset.shuffle / Dataset.prefetch in preprocess().
shuffle_buffer = 100
prefetch_buffer = 10

# Number of federated rounds executed by the training loop.
n_train_epochs = 20

# Learning rate of the per-client SGD optimizer.
client_learning_rate = 0.02
# NOTE(review): not referenced by any code visible in this file.
server_learning_rate = 1


# NOTE(review): not referenced by create_model() in this file - confirm it is still needed.
hidden_units = 256
# Dropout rate of the Dropout layer in create_model().
dropout = 0.1

# Label-flipping attack settings. make_federated_data() has parameters with
# these names, but the visible call site passes literals instead of these values.
mal_users_percentage = 0.2
# todo: could also be a list of values
target_value = 3
poisoned_value = 8

# Create (or truncate) the file that client_fake_update prints client losses to.
open('/home/nikos/msc-thesis/tmp/losses', 'w').close()

In [3]:
def batch_format(element):
    """Convert a batched EMNIST element into the `x` / `y` layout used for training.

    Args:
        element: mapping with a 'pixels' and a 'label' entry.

    Returns:
        OrderedDict with `x` reshaped to [-1, 28, 28] and `y` reshaped to [-1, 1].
    """
    images = tf.reshape(element['pixels'], [-1, 28, 28])
    # Labels become a column vector: one row per example.
    labels = tf.reshape(element['label'], [-1, 1])
    return collections.OrderedDict(x=images, y=labels)

def preprocess(dataset, train):
    """Repeat, shuffle, batch, format and prefetch a client dataset.

    Args:
        dataset: object exposing the tf.data-style `repeat`, `shuffle`, `batch`,
            `map` and `prefetch` methods.
        train: selects the training settings when it compares equal to True,
            otherwise the test settings are used.

    Returns:
        The transformed dataset.
    """
    # Deliberately `== True` (not truthiness) to keep the original selection rule.
    use_train_settings = train == True  # noqa: E712
    if use_train_settings:
        repeat_count, batch_size = n_train_dataset_epochs, batch_size_train
    else:
        repeat_count, batch_size = n_test_dataset_epochs, batch_size_test

    return (
        dataset.repeat(repeat_count)
        .shuffle(shuffle_buffer, seed=1)
        .batch(batch_size)
        .map(batch_format)
        .prefetch(prefetch_buffer)
    )

In [4]:
# Load the federated EMNIST train / test client datasets.
emnist_train, emnist_test = tff.simulation.datasets.emnist.load_data()

# Take the first client's data as a representative example ...
example_dataset = emnist_train.create_tf_dataset_for_client(
    emnist_train.client_ids[0])

# ... and preprocess it with the training settings; mnist_model() reads its
# element_spec as the model input spec.
preprocessed_example_dataset = preprocess(example_dataset, True)

In [5]:
def poison_dataset(dataset, target_honest, target_mal):
    """Return `dataset` with every `target_honest` label replaced by `target_mal`.

    Args:
        dataset: dataset whose elements are mappings with a 'y' label entry.
        target_honest: label value to look for.
        target_mal: label value written in its place.

    Returns:
        The result of mapping the relabelling function over `dataset`.
    """
    def relabel(element):
        original_labels = element['y']
        is_target = tf.equal(original_labels, target_honest)
        # Labels that do not match the target are kept unchanged.
        element['y'] = tf.where(is_target, target_mal, original_labels)
        return element

    return dataset.map(relabel)

In [6]:
def make_federated_data(client_data, client_ids, target_value, poisoned_value, train, mal_users_percentage=0):
    """Build one preprocessed dataset per client, poisoning a random subset.

    Args:
        client_data: object exposing `create_tf_dataset_for_client(client_id)`.
        client_ids: iterable of client identifiers to include.
        target_value: label that poisoned clients have rewritten.
        poisoned_value: label written in place of `target_value`.
        train: forwarded to `preprocess` to select train or test settings.
        mal_users_percentage: per-client probability of being poisoned.

    Returns:
        List of datasets, in the same order as `client_ids`.
    """
    federated = []
    for client_id in client_ids:
        client_dataset = preprocess(
            client_data.create_tf_dataset_for_client(client_id), train)

        # One random draw per client, made even when the probability is 0,
        # so the sequence of random numbers consumed stays the same.
        if random.random() < mal_users_percentage:
            client_dataset = poison_dataset(
                client_dataset, target_value, poisoned_value)

        federated.append(client_dataset)

    return federated

In [7]:
class SpecificClassRecall(tf.keras.metrics.Metric):
    """Recall for a single class, computed by delegating to a Keras `Recall` metric.

    Both the true labels and the arg-max predictions are turned into boolean
    "is this `class_id`?" tensors before being handed to the wrapped metric.
    """

    def __init__(self, class_id, name='specific_class_recall', **kwargs):
        super().__init__(name=name, **kwargs)
        self.class_id = class_id
        # Wrapped metric that does the actual recall bookkeeping.
        self.recall = tf.keras.metrics.Recall()

    def update_state(self, y_true, y_pred, sample_weight=None):
        """Accumulate one batch of labels and per-class prediction scores."""
        true_is_class = tf.equal(y_true, self.class_id)
        predicted_class = tf.argmax(y_pred, axis=1)
        pred_is_class = tf.equal(predicted_class, self.class_id)
        self.recall.update_state(true_is_class, pred_is_class, sample_weight)

    def result(self):
        """Return the current value of the wrapped recall metric."""
        return self.recall.result()

    def reset_states(self):
        """Reset the wrapped recall metric."""
        self.recall.reset_states()

    def get_config(self):
        """Return the base metric config extended with `class_id`."""
        base_config = super().get_config()
        base_config.update({"class_id": self.class_id})
        return base_config

In [8]:
def create_model():
    """Build the (uncompiled) Keras CNN used by both clients and evaluation.

    Layers, in order: Reshape to (28, 28, 1), Conv2D with 32 3x3 filters and
    ReLU, 2x2 max pooling, Flatten, Dropout with the module-level `dropout`
    rate, and a 10-unit softmax Dense layer.
    """
    layers = tf.keras.layers
    stack = [
        layers.Reshape(input_shape=(28,28,1), target_shape=(28,28,1)),
        layers.Conv2D(filters=32, kernel_size=(3,3), activation='relu'),
        layers.MaxPooling2D(pool_size=(2,2)),
        layers.Flatten(),
        layers.Dropout(dropout),
        layers.Dense(10, activation='softmax'),
    ]
    return tf.keras.models.Sequential(stack)

In [9]:
def mnist_model():
    """Wrap a fresh Keras model as a TFF learning model.

    The input spec is taken from `preprocessed_example_dataset`; the loss is
    sparse categorical cross-entropy, and the tracked metrics are sparse
    categorical accuracy plus the recall of class 3.
    """
    fresh_keras_model = create_model()
    loss_fn = tf.keras.losses.SparseCategoricalCrossentropy()
    tracked_metrics = [
        tf.keras.metrics.SparseCategoricalAccuracy(),
        SpecificClassRecall(class_id = 3),
    ]
    return tff.learning.models.from_keras_model(
        fresh_keras_model,
        input_spec = preprocessed_example_dataset.element_spec,
        loss = loss_fn,
        metrics = tracked_metrics)

In [10]:
# Throwaway model instance, built only to read its input spec.
whimsy_model = mnist_model()
# TFF type of one client's dataset: a sequence of elements matching the model input spec.
tf_dataset_type = tff.SequenceType(whimsy_model.input_spec)

In [11]:
# Accumulator for per-client total losses recorded through append_loss().
# It has to exist at module level: append_loss() appends to it and the final
# notebook cell prints it.
clients_loss = []


def append_loss(total_loss):
    """Record a scalar loss in `clients_loss` and pass the value through.

    Args:
        total_loss: object exposing `.numpy()` whose result has `.item()`
            (for example an eager scalar tensor).

    Returns:
        `total_loss` unchanged, so the function can wrap a value in place.
    """
    clients_loss.append(total_loss.numpy().item())
    return total_loss

In [12]:
@tf.function
def client_fake_update(model, dataset, server_weights, client_optimizer):
  """Performs training (using the server model weights) on the client's dataset.

  Besides returning the result, the accumulated loss is also printed to the
  losses file via `tf.print`.

  Args:
    model: model exposing `trainable_variables` and `forward_pass(batch)`.
    dataset: iterable of batches to train on.
    server_weights: weights assigned to the model before training starts.
    client_optimizer: optimizer exposing `apply_gradients`.

  Returns:
    A `(client_weights, total_loss)` pair passed through `tf.identity`, where
    `total_loss` is the sum of the per-batch losses.
  """
  # Initialize the client model with the current server weights.
  client_weights = model.trainable_variables
  # Assign the server weights to the client model.
  tf.nest.map_structure(lambda x, y: x.assign(y),
                        client_weights, server_weights)


  # Initialize a variable to accumulate the total loss.
  total_loss = 0.0
  
  # Use the client_optimizer to update the local model.
  for batch in dataset:
    with tf.GradientTape() as tape:
      # Compute a forward pass on the batch of data
      outputs = model.forward_pass(batch)

    # Add the current batch loss to the total loss.
    total_loss += outputs.loss
    
    # Compute the corresponding gradient
    grads = tape.gradient(outputs.loss, client_weights)
    grads_and_vars = zip(grads, client_weights)

    # Apply the gradient using a client optimizer.
    client_optimizer.apply_gradients(grads_and_vars)
    
    
  # Earlier attempts at getting the loss out of the graph, kept for reference:
  # total_loss = tf.py_function(func=append_loss, inp=[total_loss], Tout=tf.float32)
  # total_loss.set_shape(())

  # clients_loss.append(total_loss.numpy().item())
  # clients_loss.append(tf.print(total_loss))
  # Side channel: print this client's total loss to the losses file, which
  # next_fn reads back with read_last_n_lines.
  tf.print(total_loss, output_stream = "file:///home/nikos/msc-thesis/tmp/losses")
  return tf.nest.map_structure(tf.identity, (client_weights, total_loss))

In [13]:
@tf.function
def client_update(model, dataset, server_weights, client_optimizer, threshold):
  """Trains on the client's dataset and discards the update if the loss is too high.

  The model is initialised with `server_weights` and trained for one pass over
  `dataset`. If the summed batch loss exceeds `threshold`, the client reports
  the unmodified `server_weights` instead of its trained weights, so a
  suspicious client does not move the average.

  Args:
    model: model exposing `trainable_variables` and `forward_pass(batch)`.
    dataset: iterable of batches to train on.
    server_weights: weights received from the server, structured like
      `model.trainable_variables`.
    client_optimizer: optimizer exposing `apply_gradients`.
    threshold: scalar loss value above which the local update is rejected.

  Returns:
    A `(weights, total_loss)` pair of tensors.
  """
  # Initialize the client model with the current server weights.
  client_weights = model.trainable_variables
  tf.nest.map_structure(lambda x, y: x.assign(y),
                        client_weights, server_weights)

  # Accumulates the sum of the per-batch losses.
  total_loss = 0.0

  # Use the client_optimizer to update the local model.
  for batch in dataset:
    with tf.GradientTape() as tape:
      # Compute a forward pass on the batch of data
      outputs = model.forward_pass(batch)

    # Add the current batch loss to the total loss.
    total_loss += outputs.loss

    # Compute the corresponding gradient and apply it.
    grads = tape.gradient(outputs.loss, client_weights)
    grads_and_vars = zip(grads, client_weights)
    client_optimizer.apply_gradients(grads_and_vars)

  # `threshold` is an ordinary value here: federated intrinsics such as
  # tff.federated_value cannot be used inside TensorFlow code.
  exceeds_threshold = tf.greater(total_loss, threshold)

  # A second reference to model.trainable_variables would alias the trained
  # variables, so the fallback has to be the weights received from the server.
  # tf.where selects per tensor, keeping both outcomes structurally identical.
  reported_weights = tf.nest.map_structure(
      lambda trained, received: tf.where(exceeds_threshold, received, trained),
      client_weights, server_weights)

  return tf.nest.map_structure(tf.identity, (reported_weights, total_loss))

In [14]:
@tf.function
def server_update(model, mean_client_weights):
  """Overwrites the server model's trainable variables with the averaged client weights.

  Args:
    model: model exposing `trainable_variables`.
    mean_client_weights: values structured like `model.trainable_variables`.

  Returns:
    The model's trainable variables after the assignment.
  """
  def overwrite(variable, new_value):
    return variable.assign(new_value)

  server_variables = model.trainable_variables
  tf.nest.map_structure(overwrite, server_variables, mean_client_weights)
  return server_variables

In [15]:
@tff.tf_computation
def server_init():
  """Returns the trainable variables of a freshly built model as the initial server weights."""
  return mnist_model().trainable_variables

In [16]:
# TFF type of the model weights, taken from what server_init returns.
model_weights_type = server_init.type_signature.result

In [17]:
@tff.federated_computation
def initialize_fn():
  """Produces the initial model weights placed at the server."""
  initial_weights = server_init()
  return tff.federated_value(initial_weights, tff.SERVER)

In [18]:
@tff.tf_computation(tf_dataset_type, model_weights_type)
def client_fake_update_fn(tf_dataset, server_weights):
  """Runs `client_fake_update` with a fresh model and a fresh SGD optimizer."""
  local_model = mnist_model()
  sgd = tf.keras.optimizers.SGD(learning_rate=client_learning_rate)
  return client_fake_update(local_model, tf_dataset, server_weights, sgd)

In [19]:
# The parameter types are declared explicitly, like the sibling computations.
# The threshold is a single float32 scalar, not a sequence.
@tff.tf_computation(tf_dataset_type, model_weights_type, tf.float32)
def client_update_fn(tf_dataset, server_weights, threshold):
  """Runs `client_update` with a fresh model and a fresh SGD optimizer.

  Args:
    tf_dataset: the client's dataset (`tf_dataset_type`).
    server_weights: weights broadcast by the server (`model_weights_type`).
    threshold: scalar loss threshold forwarded to `client_update`.

  Returns:
    Whatever `client_update` returns: a `(weights, total_loss)` pair.
  """
  model = mnist_model()
  client_optimizer = tf.keras.optimizers.SGD(learning_rate=client_learning_rate)
  return client_update(model, tf_dataset, server_weights, client_optimizer, threshold)

In [20]:
@tff.tf_computation(model_weights_type)
def server_update_fn(mean_client_weights):
  """Applies `server_update` to a freshly built model."""
  server_model = mnist_model()
  return server_update(server_model, mean_client_weights)

In [21]:
# Model weights placed at the server: the first parameter type of next_fn.
federated_server_type = tff.FederatedType(model_weights_type, tff.SERVER)
# One dataset per client: the second parameter type of next_fn.
federated_dataset_type = tff.FederatedType(tf_dataset_type, tff.CLIENTS)

In [22]:
@tff.tf_computation
def dp_fn(x):
  """Placeholder for a local differential privacy transform; currently squares `x`."""
  # TODO: apply ldp
  squared = tf.square(x)
  return squared

In [23]:
def read_last_n_lines(file_path, n):
    """Return the last `n` lines of a text file, stripped of surrounding whitespace.

    Args:
        file_path: path of the file to read.
        n: maximum number of trailing lines to return.

    Returns:
        List of at most `n` stripped lines in file order. The list is empty
        when `n` is zero or negative, or when the file is empty.
    """
    # `lines[-0:]` would return the whole file, so handle "no lines" explicitly.
    if n <= 0:
        return []

    with open(file_path, 'r') as file:
        # A bounded deque keeps only the tail while streaming through the
        # file, instead of holding every line in memory.
        tail = collections.deque(file, maxlen=n)

    return [line.strip() for line in tail]

In [24]:
@tff.tf_computation(tf.float32)
def threshold_fn(threshold):
    """Identity computation on a float32 threshold; a hook where threshold logic could go."""
    return threshold

In [25]:
@tff.federated_computation(federated_server_type, federated_dataset_type)
def next_fn(server_weights, federated_dataset):
  """Runs one federated round: broadcast, local training, averaging, server update.

  Args:
    server_weights: current model weights placed at the server.
    federated_dataset: one dataset per participating client.

  Returns:
    The updated model weights placed at the server.
  """
  # Broadcast the server weights to the clients.
  server_weights_at_client = tff.federated_broadcast(server_weights)

  # federated_map / federated_broadcast operate on federated values, so the
  # threshold constant has to be placed at the server first.
  threshold = tff.federated_value(0.0, tff.SERVER)
  threshold_res = tff.federated_map(threshold_fn, threshold)
  # TODO: feed this to client_update_fn once the loss-based filtering replaces
  # the "fake" update below; client_fake_update_fn does not take a threshold.
  threshold_at_client = tff.federated_broadcast(threshold_res)

  # Each client computes their updated weights and loss. client_fake_update_fn
  # is declared with exactly two parameters (dataset, server weights).
  client_weights_and_losses = tff.federated_map(
      client_fake_update_fn, (federated_dataset, server_weights_at_client))

  # NOTE: the Python body of a federated computation runs when the computation
  # is traced, not on every round, so this reads and prints whatever the losses
  # file contains at definition time.
  losses = read_last_n_lines("/home/nikos/msc-thesis/tmp/losses", n_clients)
  print(losses[0:10])

  # Split the weights and losses.
  client_weights = client_weights_and_losses[0]
  client_losses = client_weights_and_losses[1]
  # Also trace-time only: prints the symbolic federated value, not numbers.
  print(client_losses)

  # TODO: algorithm for eliminating clients based on loss, e.g. replace the
  # weights of the clients with the highest losses by the previous round's
  # server weights before averaging.

  # The server averages these updates.
  mean_client_weights = tff.federated_mean(client_weights)
  # The server updates its model.
  server_weights = tff.federated_map(server_update_fn, mean_client_weights)

  return server_weights

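# Stale notebook output from an earlier run of the cell above (the tff.cast call is now commented out):
# AttributeError: module 'tensorflow_federated' has no attribute 'cast'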

In [None]:
# Bundle the initialization and round computations into one iterative process;
# the cells below call its .initialize() and .next().
federated_algorithm = tff.templates.IterativeProcess(
    initialize_fn=initialize_fn,
    next_fn=next_fn
)

In [None]:
emnist_train, emnist_test = tff.simulation.datasets.emnist.load_data()

# Sample exactly n_clients training clients: next_fn reads back n_clients loss
# lines per round, so the sample size must not drift from that setting.
clients = random.sample(emnist_train.client_ids, n_clients)

# No poisoning here: mal_users_percentage is left at its default of 0, so the
# two label arguments (0, 0) have no effect.
federated_train_data = make_federated_data(emnist_train, clients, 0, 0, train=True)


In [None]:
# Centralized test set: the data of all test clients merged into one dataset,
# then preprocessed with the test settings.
central_emnist_test = emnist_test.create_tf_dataset_from_all_clients()
central_emnist_test = preprocess(central_emnist_test, train=False)

In [None]:
def evaluate(server_state):
  """Evaluates the given weights on the centralized test set.

  A fresh Keras model is compiled with sparse categorical cross-entropy and
  accuracy, loaded with `server_state`, and evaluated on
  `central_emnist_test`.

  Args:
    server_state: weights accepted by `keras_model.set_weights`.
  """
  def as_features_and_labels(element):
    return (element['x'], element['y'])

  eval_model = create_model()
  eval_model.compile(
    loss=tf.keras.losses.SparseCategoricalCrossentropy(),
    metrics=[tf.keras.metrics.SparseCategoricalAccuracy()],
  )
  eval_model.set_weights(server_state)

  test_pairs = central_emnist_test.map(as_features_and_labels)
  eval_model.evaluate(test_pairs)


In [None]:
# Create the initial server weights and evaluate them as a baseline before training.
server_state = federated_algorithm.initialize()
evaluate(server_state)

In [None]:
# Run n_train_epochs federated rounds; the same client datasets are used in every round.
for epoch in tqdm(range(n_train_epochs), position = 0, leave = True):
    server_state = federated_algorithm.next(server_state, federated_train_data)    
    


In [None]:
# Evaluate the server weights after training.
evaluate(server_state)

In [None]:
# NOTE(review): presumably the list that append_loss() appends to - verify that
# clients_loss is defined before this cell runs.
print(clients_loss)

In [None]:
# # load the datasets that are going to be used
# emnist_train, emnist_test = tff.simulation.datasets.emnist.load_data()

# # lists to hold the metrics that we want to compute
# accs = []
# losses = []
# class_recall = []

# example_dataset = emnist_train.create_tf_dataset_for_client(
#     emnist_train.client_ids[0])

# preprocessed_example_dataset = preprocess(example_dataset, True)
 
     
# # build the process to have the model's architecture
# # evaluation_process = tff.learning.algorithms.build_fed_eval(mnist_model)

# # initialize the state of the evaluation
# sample_test_clients = emnist_test.client_ids[0:n_test_clients]

# federated_test_data = make_federated_data(emnist_test, sample_test_clients, 0, 0, train=False)

# # fix the random clients so that they are the same for every model
# clients = []

# for i in range(n_train_epochs):
#     clients.append(random.sample(emnist_train.client_ids, n_clients))