## Install dependencies

In [None]:
!pip install transformers
!pip install datasets
!pip install flower

## Download the data from Google Drive

In [None]:
import gdown
# Shared Google Drive folder that holds the experiment data; gdown downloads
# the folder's contents locally.
url = "https://drive.google.com/drive/folders/13ijZBGIVAm-x93YpWKqIRe9alJa-O45K?usp=sharing"
gdown.download_folder(url=url)

## Import Libraries

In [None]:
from os import listdir
from os.path import isfile, join

import datetime
import time
import torch
import random
import re

In [None]:
import gc

import numpy as np
import pandas as pd
import tensorflow as tf

from collections import OrderedDict
from typing import List, Tuple

In [None]:
import flwr as fl
import numpy as np
import matplotlib.pyplot as plt
import torch
import torch.nn as nn
import torchvision
import torch.nn.functional as F
import torchvision.transforms as transforms
from flwr.common import Metrics
from torch.utils.data import DataLoader, random_split

In [None]:
# from emoji import demojize
from transformers import AutoModel, AutoTokenizer,BertForSequenceClassification, AdamW, BertConfig,get_linear_schedule_with_warmup

from tensorflow.keras.preprocessing.sequence import pad_sequences
from nltk.tokenize import TweetTokenizer
from sklearn.metrics import precision_recall_fscore_support, classification_report
from torch.utils.data import TensorDataset, DataLoader, RandomSampler, SequentialSampler

> ### Helper functions, including tweet cleaning written for the SemEval English data

In [None]:
def flat_accuracy(preds, labels):
    """Fraction of rows whose arg-max prediction matches the gold label.

    preds:  per-class scores; the predicted class is the arg-max over axis 1.
    labels: gold class ids; flattened before the comparison.
    """
    predicted = np.argmax(preds, axis=1).ravel()
    gold = labels.ravel()
    hits = (predicted == gold).sum()
    return hits / len(gold)


def format_time(elapsed):
    """Render a duration given in seconds as ``str(timedelta)``, e.g. ``0:01:05``.

    The value is rounded to the nearest whole second first (Python ``round``).
    """
    whole_seconds = int(round(elapsed))
    return f"{datetime.timedelta(seconds=whole_seconds)}"


def normalizeToken(token):
    """Map a single tweet token to its normalized form.

    Surrounding whitespace is stripped first, then:
      * mentions (starting with "@") become "@USER";
      * URLs (starting with "http" or "www", case-insensitive) become "HTTPURL";
      * the curly apostrophe "’" becomes "'" and the ellipsis "…" becomes "...";
      * any other token is returned stripped but otherwise unchanged.
    """
    stripped = token.strip()
    if stripped.startswith("@"):
        return "@USER"
    if stripped.lower().startswith(("http", "www")):
        return "HTTPURL"
    # Typographic punctuation mapped to its ASCII equivalent.
    return {"’": "'", "…": "..."}.get(stripped, stripped)


def normalizeTweet(tweet):
    """Normalize a raw tweet into lower-cased, single-space-separated tokens.

    Steps:
      1. Replace "’" and "…" with ASCII, then tokenize with NLTK's TweetTokenizer.
      2. Normalize each token with `normalizeToken` (mentions, URLs, punctuation).
      3. Split contractions and tidy "p.m." / "a.m." with literal replacements.
      4. Re-join numbers split by the tokenizer (e.g. "12 / 2019" -> "12/2019").
      5. Lower-case and collapse runs of whitespace.
    """
    # (old, new) pairs applied strictly in this order; later rules rely on the
    # output of earlier ones (e.g. "ca n't" only appears after the "n't " split).
    literal_rules = (
        ("cannot ", "can not "),
        ("n't ", " n't "),
        ("n 't ", " n't "),
        ("ca n't", "can't"),
        ("ai n't", "ain't"),
        ("'m ", " 'm "),
        ("'re ", " 're "),
        ("'s ", " 's "),
        ("'ll ", " 'll "),
        ("'d ", " 'd "),
        ("'ve ", " 've "),
        (" p . m .", "  p.m."),
        (" p . m ", " p.m "),
        (" a . m .", " a.m."),
        (" a . m ", " a.m "),
    )
    regex_rules = (
        (r",([0-9]{2,4}) , ([0-9]{2,4})", r",\1,\2"),
        (r"([0-9]{1,3}) / ([0-9]{2,4})", r"\1/\2"),
        (r"([0-9]{1,3})- ([0-9]{2,4})", r"\1-\2"),
    )

    ascii_tweet = tweet.replace("’", "'").replace("…", "...")
    tokens = TweetTokenizer().tokenize(ascii_tweet)
    text = " ".join(normalizeToken(tok) for tok in tokens)

    for old, new in literal_rules:
        text = text.replace(old, new)
    for pattern, replacement in regex_rules:
        text = re.sub(pattern, replacement, text)

    return " ".join(text.lower().split())


from sklearn.model_selection import train_test_split

def splitting_method(df_, name1, name2, test_size=0.5):
    """Shuffle-split `df_` into two CSV files and return the held-out part.

    Only the 'sentence' and 'label' columns are used. The first part is written
    to ``<name1>.csv`` and the held-out part to ``<name2>.csv``. The split uses a
    fixed random_state (105), so reruns give the same partition.

    When ``test_size`` is anything other than 0.5, rows with a duplicate
    'sentence' are dropped from the held-out part before it is saved.

    Returns the held-out part as a DataFrame with columns ['sentence', 'label'].
    """
    features = pd.DataFrame(df_, columns=['sentence'])
    targets = pd.DataFrame(df_, columns=["label"])

    (feat_train, feat_held,
     targ_train, targ_held) = train_test_split(
        features, targets, test_size=test_size, shuffle=True, random_state=105)

    def _pair(feat_part, targ_part):
        # Put sentence and label back side by side (aligned on the shared index).
        return pd.concat(
            [pd.DataFrame(feat_part, columns=['sentence']),
             pd.DataFrame(targ_part, columns=['label'])],
            axis=1,
        )

    _pair(feat_train, targ_train).to_csv(name1 + ".csv", index=False)

    held_out = _pair(feat_held, targ_held)
    if test_size != 0.5:
        held_out = held_out.drop_duplicates('sentence')
    held_out.to_csv(name2 + ".csv", index=False)

    return held_out

> ### Using the GPU or CPU

In [None]:
# Pick the first CUDA device when one is visible, otherwise fall back to CPU.
if not torch.cuda.is_available():
    print('No GPU available, using the CPU instead.')
    device = torch.device("cpu")
else:
    device = torch.device("cuda:0")
    print('There are %d GPU(s) available.' % torch.cuda.device_count())
    print('We will use the GPU:', torch.cuda.get_device_name(0))

> ### Read the original SemEval dataset

In [None]:
# Merge the Spanish, Italian and French training CSVs into one file, then
# reload it with explicit column names (its header row is skipped).
temp = pd.concat(map(pd.read_csv, ['/home/jovyan/conda-envs/dj/DATA/Karim_140K_es.csv', '/home/jovyan/conda-envs/dj/DATA/Karim_300K_it.csv', '/home/jovyan/conda-envs/dj/DATA/Karim_450K_fr.csv']))
temp.to_csv("training_es_it_fr.csv", index=False)
df_3L_train = pd.read_csv('training_es_it_fr.csv', skiprows=1, names=['Tweet', 'Label_2', 'label', 'sentence'])


def _read_text_and_labels(text_path, labels_path):
    """Read a one-tweet-per-line text file and its one-label-per-line labels file.

    The two-newline separator never matches inside a single line, so every line
    becomes one value. The python engine is requested explicitly; pandas would
    fall back to it anyway for a multi-character separator, but with a warning.

    read_csv skips blank lines, so an empty tweet would silently shift every
    following tweet against its label. The row counts are therefore checked.

    Returns (texts, labels): single-column DataFrames 'sentence' and 'label'.
    Raises ValueError if the files have different numbers of rows.
    """
    texts = pd.read_csv(text_path, sep='\n\n', names=['sentence'], engine='python')
    labels = pd.read_csv(labels_path, sep='\n\n', names=['label'], engine='python')
    if len(texts) != len(labels):
        raise ValueError(
            f"{text_path} has {len(texts)} rows but {labels_path} has {len(labels)}; "
            "tweets and labels would be misaligned."
        )
    return texts, labels


# SemEval test split -> devFile.csv (columns: sentence, label).
df1, df2 = _read_text_and_labels(
    '/home/jovyan/conda-envs/dj/DATA/test/us_test.text',
    '/home/jovyan/conda-envs/dj/DATA/test/us_test.labels',
)
df = pd.concat([df1, df2], axis=1)
df.to_csv("devFile.csv", index=False)

# Labelled training tweets; split into centralized / federated halves in the next cell.
df1, df2 = _read_text_and_labels(
    '/home/jovyan/conda-envs/dj/DATA/us/tweet_by_ID_28_1_2019__06_28_21.txt.text',
    '/home/jovyan/conda-envs/dj/DATA/us/tweet_by_ID_28_1_2019__06_28_21.txt.labels',
)
df = pd.concat([df1, df2], axis=1)


In [None]:
# Split the labelled tweets 50/50: one half is saved as centralized_dataset.csv,
# the other as fedrated_dataset.csv (and returned, so it is displayed here).
splitting_method(df, name1='centralized_dataset', name2='fedrated_dataset', test_size=0.5)

> ### Model initialization

In [None]:
from transformers import AutoModelForSequenceClassification

# Hugging Face Hub id of the pretrained encoder; the matching tokenizer is
# loaded from the same id.
CHECKPOINT = "Twitter/twhin-bert-base"

# Sequence-classification head with 20 output classes. Attention maps and
# hidden states are not returned.
model = AutoModelForSequenceClassification.from_pretrained(
    CHECKPOINT,
    num_labels=20,
    output_attentions=False,
    output_hidden_states=False,
)

> ### Load the pretrained model.

In [None]:
# Checkpoint saved with `torch.save(model, path)` (a whole pickled module, as
# FlowerClient.evaluate writes it). Replace the placeholder with a real path,
# or set it to None to keep the freshly initialized model from the cell above.
PRETRAINED_MODEL_PATH = 'model path'

if PRETRAINED_MODEL_PATH is not None:
    if not isfile(PRETRAINED_MODEL_PATH):
        raise FileNotFoundError(
            f"Pretrained model not found at {PRETRAINED_MODEL_PATH!r}; set "
            "PRETRAINED_MODEL_PATH to a .pt file saved with torch.save(model, ...) "
            "or to None."
        )
    # weights_only=False is needed to unpickle a whole nn.Module (torch>=2.6
    # defaults to weights_only=True). Unpickling can run arbitrary code, so
    # only load checkpoints from a trusted source.
    model = torch.load(PRETRAINED_MODEL_PATH, map_location=device, weights_only=False)

model.to(device)

tokenizer = AutoTokenizer.from_pretrained(CHECKPOINT)

> ### Dataframe to dataloader function

In [None]:
def df_to_dataloader(data_frame, max_len=64, batch_size=64, shuffle=True, tokenizer=None):
    """Turn a DataFrame of tweets into a DataLoader of (input_ids, attention_mask, label).

    Parameters
    ----------
    data_frame : pandas.DataFrame
        Needs a 'sentence' column of raw tweet text and a 'label' column of
        integer class ids. The frame is not modified: rows missing either value
        are dropped from a copy, and the copy's sentences are normalized with
        `normalizeTweet`.
    max_len : int
        Every sequence is padded / truncated to this many tokens (default 64).
    batch_size : int
        Batch size of the returned DataLoader (default 64).
    shuffle : bool
        True (default) draws batches with a RandomSampler; pass False (e.g. for
        evaluation) to iterate in order with a SequentialSampler.
    tokenizer : optional
        Pre-loaded Hugging Face tokenizer. By default it is loaded from
        CHECKPOINT on every call.

    Returns
    -------
    torch.utils.data.DataLoader

    Raises
    ------
    ValueError
        If no row has both a sentence and a label.
    """
    # Work on a copy so the caller's frame keeps its raw text. NaNs must be
    # dropped *before* normalization, which would fail on a float NaN.
    df = data_frame.dropna(subset=['sentence', 'label']).copy()
    print('Number of training sentences: {:,}\n'.format(df.shape[0]))
    if df.empty:
        raise ValueError("df_to_dataloader: no rows with both 'sentence' and 'label'.")

    df['sentence'] = df['sentence'].astype(str).apply(normalizeTweet)

    if tokenizer is None:
        tokenizer = AutoTokenizer.from_pretrained(CHECKPOINT)

    print('\nPadding/truncating all sentences to %d values...' % max_len)
    print('\nPadding token: "{:}", ID: {:}'.format(tokenizer.pad_token, tokenizer.pad_token_id))
    # Let the tokenizer pad with its own pad id and build the attention mask.
    # Padding with 0 and masking `id > 0` is only right when the pad id is 0;
    # in RoBERTa/XLM-R-style vocabularies 0 is <s> and <pad> is 1, which would
    # mask out the first token and attend to padding. Truncation here also keeps
    # the closing special token.
    encoded = tokenizer(
        list(df['sentence']),
        padding='max_length',
        truncation=True,
        max_length=max_len,
        return_tensors='pt',
    )
    print('\nDone.')

    # Labels can be read as floats when the raw column contained NaNs; the
    # classification loss expects integer class ids.
    labels = torch.tensor(df['label'].astype('int64').to_numpy())
    print("train_labels: ", set(labels.tolist()))

    dataset = TensorDataset(encoded['input_ids'], encoded['attention_mask'], labels)
    sampler = RandomSampler(dataset) if shuffle else SequentialSampler(dataset)
    return DataLoader(dataset, sampler=sampler, batch_size=batch_size)


> ### Test datasets

In [None]:
# Federated half of the split written by `splitting_method` ('fedrated_dataset.csv').
fed_data = pd.read_csv('fedrated_dataset.csv')

# SemEval test set saved as devFile.csv. The notebook's loader helper is
# `df_to_dataloader`, and it takes a DataFrame, so read the CSV first.
validation_dataloader = df_to_dataloader(pd.read_csv('devFile.csv'))


def _load_test_set(source_path, output_csv):
    """Read a raw test CSV under the standard column names and save a local copy.

    The source header row is skipped so the explicit names replace it.
    Returns the loaded DataFrame.
    """
    frame = pd.read_csv(source_path, skiprows=1,
                        names=['Tweet', 'Label_2', 'label', 'sentence'])
    frame.to_csv(output_csv, index=False)
    return frame


test_German = _load_test_set('/DATA/german_zero_shot_54K.csv', "test_German.csv")
test_old_en = _load_test_set('/DATA/test_en_27k.csv', "test_old_en.csv")
test_es = _load_test_set('/DATA/test_es_29k_final.csv', "test_es.csv")
test_fr = _load_test_set('/DATA/test_fr_63k_final.csv', "test_fr.csv")
test_it = _load_test_set('/DATA/test_it_36k_final.csv', "test_it.csv")


> Convert each saved test set into a dataloader

In [None]:
# `df_to_dataloader` takes a DataFrame (there is no path-based loader helper),
# so read each saved test CSV first.
German_validation_dataloader = df_to_dataloader(pd.read_csv('test_German.csv'))
es_validation_dataloader = df_to_dataloader(pd.read_csv('test_es.csv'))
fr_validation_dataloader = df_to_dataloader(pd.read_csv('test_fr.csv'))
it_validation_dataloader = df_to_dataloader(pd.read_csv('test_it.csv'))

> ### Non-IID training dataset

In [None]:
def _stage_client_csv(source_path, staged_csv, column_names):
    """Copy one client's raw CSV to `staged_csv`, then reload it under `column_names`.

    The staged copy's header row is skipped so the explicit names replace it.
    """
    pd.read_csv(source_path).to_csv(staged_csv, index=False)
    return pd.read_csv(staged_csv, skiprows=1, names=column_names)


# Clients 0 and 1 share one column layout; clients 2 and 3 share another.
df_train_non_c_zero_toxic = _stage_client_csv(
    '/DATA/new_toxic_2_4_non_iid/Client_0_EN_non_IID_Toxic_Data.csv',
    "training_non_iid_c_zero.csv",
    ['sentence', 'Label_2', 'label'],
)

df_train_non_c_1 = _stage_client_csv(
    '/DATA/new_toxic_2_4_non_iid/Client_1_FR_non_IID_Toxic_Data.csv',
    "training_non_iid_c_1.csv",
    ['sentence', 'Label_2', 'label'],
)

df_train_non_c_2 = _stage_client_csv(
    '/DATA/new_toxic_2_4_non_iid/NonIID_It_clients_151K.csv',
    "training_non_iid_c_2.csv",
    ['sentence', 'label', 'sen_2', 'Label_2'],
)

df_train_non_c_3 = _stage_client_csv(
    '/DATA/new_toxic_2_4_non_iid/NonIID_Es_clients_70K.csv',
    "training_non_iid_c_3.csv",
    ['sentence', 'label', 'sen_2', 'Label_2'],
)


> ### Training

In [None]:
# Passes per call of train_fun; used for the LR-schedule length and the banner.
epochs = 1


def train_fun(epoch_i, train_dataloader, model):
    """Run one training pass of `model` over `train_dataloader`.

    A fresh AdamW optimizer (lr 2e-5, eps 1e-8, no weight decay) and a linear
    schedule without warm-up are created on every call, so no optimizer state
    carries over between calls. The seeds of `random`, numpy and torch are reset
    to 42 at the start of each call.

    Parameters
    ----------
    epoch_i : int
        Pass index; only used in the progress banner (printed as epoch_i + 1).
    train_dataloader : DataLoader
        Yields (input_ids, attention_mask, labels) batches.
    model : torch.nn.Module
        Classifier called with ``labels=`` so that ``outputs[0]`` is the loss;
        trained in place on the global ``device``.

    Returns
    -------
    float
        Average training loss over the batches (0.0 if the loader is empty).
    """
    num_batches = len(train_dataloader)
    if num_batches == 0:
        print('No training batches; skipping training.')
        return 0.0

    # torch.optim.AdamW replaces the deprecated transformers.AdamW. Its default
    # weight_decay is 0.01, so 0.0 is passed to keep the previous behaviour.
    optimizer = torch.optim.AdamW(model.parameters(), lr=2e-5, eps=1e-8, weight_decay=0.0)
    scheduler = get_linear_schedule_with_warmup(
        optimizer,
        num_warmup_steps=0,  # Default value in run_glue.py
        num_training_steps=num_batches * epochs,
    )

    # This training code is based on the `run_glue.py` script here:
    # https://github.com/huggingface/transformers/blob/5bfcd0485ece086ebcbed2d008813037968a9e58/examples/run_glue.py#L128
    seed_val = 42
    random.seed(seed_val)
    np.random.seed(seed_val)
    torch.manual_seed(seed_val)
    torch.cuda.manual_seed_all(seed_val)

    print("")
    print('======== Epoch {:} / {:} ========'.format(epoch_i + 1, epochs))
    print('Training...')
    t0 = time.time()
    total_loss = 0.0
    # train() only switches layers such as dropout to training behaviour;
    # it does not perform any training itself.
    model.train()

    for step, batch in enumerate(train_dataloader):
        if step % 40 == 0 and step != 0:
            elapsed = format_time(time.time() - t0)
            print('  Batch {:>5,}  of  {:>5,}.    Elapsed: {:}.'.format(step, num_batches, elapsed))

        b_input_ids = batch[0].to(device)
        b_input_mask = batch[1].to(device)
        b_labels = batch[2].to(device)

        # PyTorch accumulates gradients, so clear them before each backward pass.
        model.zero_grad()
        outputs = model(b_input_ids,
                        token_type_ids=None,
                        attention_mask=b_input_mask,
                        labels=b_labels)
        # With labels passed, the first output is the loss.
        loss = outputs[0]
        total_loss += loss.item()
        loss.backward()
        torch.nn.utils.clip_grad_norm_(model.parameters(), 1.0)
        optimizer.step()
        scheduler.step()

    # Free Python garbage and cached CUDA blocks once per pass rather than after
    # every batch: a full gc.collect() per step is slow and the CUDA cache is
    # refilled by the very next batch anyway.
    gc.collect()
    torch.cuda.empty_cache()

    avg_train_loss = total_loss / num_batches
    print("")
    print("  Average training loss: {0:.2f}".format(avg_train_loss))
    print("  Training epcoh took: {:}".format(format_time(time.time() - t0)))
    return avg_train_loss


> ### Validation

In [None]:
def test_fun(validation_dataloader, model):
    """Evaluate `model` on `validation_dataloader` and return ``(loss, accuracy)``.

    Prints the accuracy, the elapsed time and a per-class classification report.

    Parameters
    ----------
    validation_dataloader : DataLoader
        Yields (input_ids, attention_mask, labels) batches.
    model : torch.nn.Module
        Classifier; called with ``labels=`` so ``outputs[0]`` is the batch loss
        and ``outputs[1]`` the logits.

    Returns
    -------
    (float, float)
        Mean loss per example and the fraction of examples classified
        correctly (both weighted per example, so a short final batch is not
        over-counted). ``(0.0, 0.0)`` for an empty loader.
    """
    print("")
    print("Running Validation...")
    t0 = time.time()
    model.eval()

    total_loss = 0.0
    n_examples = 0
    predictions, true_labels = [], []

    for batch in validation_dataloader:
        b_input_ids, b_input_mask, b_labels = (t.to(device) for t in batch)

        with torch.no_grad():
            # Passing labels makes the model also return the loss; `.long()`
            # guards against labels that were loaded as floats.
            outputs = model(b_input_ids,
                            token_type_ids=None,
                            attention_mask=b_input_mask,
                            labels=b_labels.long())

        batch_size = b_labels.size(0)
        # outputs[0] is averaged over the batch; re-weight by its size.
        total_loss += outputs[0].item() * batch_size
        n_examples += batch_size
        predictions.append(outputs[1].detach().cpu().numpy())
        true_labels.append(b_labels.cpu().numpy())

    if n_examples == 0:
        print('    No validation examples; skipping.')
        return 0.0, 0.0

    flat_predictions = np.argmax(np.concatenate(predictions, axis=0), axis=1)
    flat_true_labels = np.concatenate(true_labels, axis=0)
    accuracy = float(np.mean(flat_predictions == flat_true_labels))
    avg_loss = total_loss / n_examples

    print('    DONE.')
    print("  Loss: {0:.4f}".format(avg_loss))
    print("  Accuracy: {0:.2f}".format(accuracy))
    print("  Validation took: {:}".format(format_time(time.time() - t0)))
    print(classification_report(flat_true_labels, flat_predictions, digits=5))

    return avg_loss, accuracy

### Implementing a Flower client

With that out of the way, let's move on to the interesting part. Federated learning systems consist of a server and multiple clients. In Flower, we create clients by implementing subclasses of `flwr.client.Client` or `flwr.client.NumPyClient`. We use `NumPyClient` in this tutorial because it is easier to implement and requires us to write less boilerplate.

To implement the Flower client, we create a subclass of `flwr.client.NumPyClient` and implement the three methods `get_parameters`, `fit`, and `evaluate`:

* `get_parameters`: Return the current local model parameters
* `fit`: Receive model parameters from the server, train the model parameters on the local data, and return the (updated) model parameters to the server
* `evaluate`: Receive model parameters from the server, evaluate the model parameters on the local data, and return the evaluation result to the server

We mentioned that our clients will use the previously defined PyTorch components for model training and evaluation. Let's see a simple Flower client implementation that brings everything together:

In [None]:
class FlowerClient(fl.client.NumPyClient):
    """Flower NumPy client wrapping the shared classifier.

    `fit` trains on the loader selected by `cid`:
    "0" -> trainloader, "1" -> df_train_non_fr_trainloader,
    "2" -> df_train_non_Es_trainloader, "3" -> df_train_non_It_trainloader.
    `evaluate` scores the model on the SemEval, ES, FR, IT and German test
    loaders, saves a model snapshot, and reports the SemEval results as the
    primary loss/accuracy with the other accuracies as extra metrics.
    """

    def __init__(self, net, trainloader, validation_dataloader, es_validation_dataloader, fr_validation_dataloader, it_validation_dataloader, German_validation_dataloader, df_train_non_fr_trainloader, df_train_non_Es_trainloader, df_train_non_It_trainloader, cid):
        self.net = net
        self.trainloader = trainloader
        self.validation_dataloader = validation_dataloader
        self.es_validation_dataloader = es_validation_dataloader
        self.fr_validation_dataloader = fr_validation_dataloader
        self.it_validation_dataloader = it_validation_dataloader
        self.German_validation_dataloader = German_validation_dataloader

        self.cid = cid

        self.df_train_non_fr_trainloader = df_train_non_fr_trainloader
        self.df_train_non_Es_trainloader = df_train_non_Es_trainloader
        self.df_train_non_It_trainloader = df_train_non_It_trainloader

    def _train_loader_for_cid(self):
        """Return the training loader this client id trains on.

        Raises ValueError for an unknown cid instead of silently skipping
        training (which would hand the server's parameters straight back).
        """
        loaders = {
            '0': self.trainloader,
            '1': self.df_train_non_fr_trainloader,
            '2': self.df_train_non_Es_trainloader,
            '3': self.df_train_non_It_trainloader,
        }
        if self.cid not in loaders:
            raise ValueError(f"FlowerClient: no training data configured for cid {self.cid!r}")
        return loaders[self.cid]

    def get_parameters(self, config):
        """Return the model's state_dict values as NumPy arrays, in key order."""
        return [val.cpu().numpy() for _, val in self.net.state_dict().items()]

    def set_parameters(self, parameters):
        """Load NumPy arrays (in state_dict key order) into the model."""
        params_dict = zip(self.net.state_dict().keys(), parameters)
        # torch.tensor keeps each array's dtype (so any integer buffers stay
        # integer); torch.Tensor would cast every entry to float32.
        state_dict = OrderedDict({k: torch.tensor(v) for k, v in params_dict})
        self.net.load_state_dict(state_dict, strict=True)

    def fit(self, parameters, config):
        """Train on this client's data; return (parameters, num_examples, metrics)."""
        self.set_parameters(parameters)

        loader = self._train_loader_for_cid()
        # epoch_i=0 so the banner reads "Epoch 1 / 1".
        train_fun(0, loader, self.net)
        gc.collect()
        torch.cuda.empty_cache()

        # Report the number of examples actually trained on (len(loader) would
        # count batches).
        return self.get_parameters(config={}), len(loader.dataset), {}

    def _save_model_snapshot(self):
        """Save the whole model under a file name not yet present in the models dir."""
        mypath = '/home/jovyan/conda-envs/dj/models_'
        onlyfiles = {f for f in listdir(mypath) if isfile(join(mypath, f))}
        new_name = 'n1_ex2_toxic_2table_2_4_noniid_Mmini_Krum_2g_epoch_R5.pt'
        # Insert "_R_" before the ".pt" suffix until the name is unused.
        while new_name in onlyfiles:
            new_name = new_name[:-3] + "_R_" + new_name[-3:]
        torch.save(self.net, join(mypath, new_name))

    def evaluate(self, parameters, config):
        """Evaluate on all test sets; return (SemEval loss, SemEval size, metrics)."""
        self.set_parameters(parameters)

        eval_sets = [
            ('semeval', 'SemEval_test_>>>', self.validation_dataloader),
            ('es', 'ES_test_>>>', self.es_validation_dataloader),
            ('fr', 'FR_test_>>>', self.fr_validation_dataloader),
            ('it', 'IT_test_>>>', self.it_validation_dataloader),
            ('gr', 'GR_test_>>>', self.German_validation_dataloader),
        ]
        results = {}
        for key, banner, loader in eval_sets:
            print(banner)
            results[key] = test_fun(loader, self.net)

        self._save_model_snapshot()

        # SemEval is the primary set (its size is the reported num_examples);
        # the other accuracies are returned as extra metrics instead of being
        # overwritten.
        loss, accuracy = results['semeval']
        metrics = {"accuracy": float(accuracy)}
        for key, (_, acc) in results.items():
            if key != 'semeval':
                metrics[f"accuracy_{key}"] = float(acc)

        return float(loss), len(self.validation_dataloader.dataset), metrics
    
    


### Using the Virtual Client Engine

In this notebook, we want to simulate a federated learning system with 4 clients on a single machine. This means that the server and all 4 clients will live on a single machine and share resources such as CPU, GPU, and memory. Having 4 clients would mean having 4 instances of `FlowerClient` in memory. Doing this on a single machine can quickly exhaust the available memory resources, even if only a subset of these clients participates in a single round of federated learning. Flower's Virtual Client Engine avoids this by creating each client on demand: it calls `client_fn` with the client's id only when that client is needed.



In [None]:
NUM_CLIENTS = 4


def client_fn(cid: str) -> FlowerClient:
    """Create the Flower client for simulated client `cid` ("0" .. "3").

    All clients share the global `model` and the global validation loaders.
    Tokenizing a training set is expensive and this function runs whenever the
    simulation needs a client instance, so only the training data this cid
    trains on is converted:

    * cid "0" -> df_train_non_c_zero_toxic (as `trainloader`)
    * cid "1" -> df_train_non_c_1
    * cid "2" -> df_train_non_c_2
    * cid "3" -> df_train_non_c_3

    Training loaders a client does not use are passed as None. `trainloader`
    is always built because FlowerClient may read it whatever the cid (e.g. to
    report a number of training examples).
    """
    trainloader = df_to_dataloader(df_train_non_c_zero_toxic)

    df_train_non_fr_trainloader = (
        df_to_dataloader(df_train_non_c_1) if cid == '1' else None)
    # NOTE(review): despite the parameter names, df_train_non_c_2 is read from
    # the Italian file (NonIID_It_clients_151K.csv) and df_train_non_c_3 from
    # the Spanish one (NonIID_Es_clients_70K.csv); the cid -> data mapping is
    # kept as it was.
    df_train_non_Es_trainloader = (
        df_to_dataloader(df_train_non_c_2) if cid == '2' else None)
    df_train_non_It_trainloader = (
        df_to_dataloader(df_train_non_c_3) if cid == '3' else None)

    # Create a single Flower client representing a single organization
    return FlowerClient(model, trainloader, validation_dataloader, es_validation_dataloader, fr_validation_dataloader, it_validation_dataloader, German_validation_dataloader, df_train_non_fr_trainloader, df_train_non_Es_trainloader, df_train_non_It_trainloader, cid)


In [None]:
from logging import WARNING
from typing import Callable, Dict, List, Optional, Tuple, Union

from flwr.common import (
    FitRes,
    MetricsAggregationFn,
    NDArrays,
    Parameters,
    Scalar,
    ndarrays_to_parameters,
    parameters_to_ndarrays,
)
from flwr.common.logger import log
from flwr.server.client_proxy import ClientProxy

from aggregate import aggregate_krum
from flwr.server.strategy.fedavg import FedAvg

# FedAvg = fl.server.strategy.FedAvg()

# Message logged at WARNING level by Krum.__init__ when min_available_clients is
# smaller than min_fit_clients or min_evaluate_clients.
WARNING_MIN_AVAILABLE_CLIENTS_TOO_LOW = """
Setting `min_available_clients` lower than `min_fit_clients` or
`min_evaluate_clients` can cause the server to fail when there are too few clients
connected to the server. `min_available_clients` must be set to a value larger
than or equal to the values of `min_fit_clients` and `min_evaluate_clients`.
"""


In [None]:
# flake8: noqa: E501
class Krum(FedAvg):
    """Configurable Krum strategy implementation.

    Subclasses FedAvg and overrides only `aggregate_fit`: client sampling,
    configuration and evaluation are inherited, while fit results are combined
    with `aggregate_krum` instead of a weighted average.
    """

    # pylint: disable=too-many-arguments,too-many-instance-attributes,line-too-long
    def __init__(
        self,
        *,
        fraction_fit: float = 1.0,
        fraction_evaluate: float = 1.0,
        min_fit_clients: int = 2,
        min_evaluate_clients: int = 2,
        min_available_clients: int = 2,
        num_malicious_clients: int = 0,
        num_clients_to_keep: int = 0,
        evaluate_fn: Optional[
            Callable[
                [int, NDArrays, Dict[str, Scalar]],
                Optional[Tuple[float, Dict[str, Scalar]]],
            ]
        ] = None,
        on_fit_config_fn: Optional[Callable[[int], Dict[str, Scalar]]] = None,
        on_evaluate_config_fn: Optional[Callable[[int], Dict[str, Scalar]]] = None,
        accept_failures: bool = True,
        initial_parameters: Optional[Parameters] = None,
        fit_metrics_aggregation_fn: Optional[MetricsAggregationFn] = None,
        evaluate_metrics_aggregation_fn: Optional[MetricsAggregationFn] = None,
    ) -> None:
        """Configurable Krum strategy.

        Parameters
        ----------
        fraction_fit : float, optional
            Fraction of clients used during training. Defaults to 1.0.
        fraction_evaluate : float, optional
            Fraction of clients used during validation. Defaults to 1.0.
        min_fit_clients : int, optional
            Minimum number of clients used during training. Defaults to 2.
        min_evaluate_clients : int, optional
            Minimum number of clients used during validation. Defaults to 2.
        min_available_clients : int, optional
            Minimum number of total clients in the system. Defaults to 2.
            A warning is logged if it is smaller than min_fit_clients or
            min_evaluate_clients.
        num_malicious_clients : int, optional
            Number of malicious clients in the system. Defaults to 0.
        num_clients_to_keep : int, optional
            Number of clients to keep before averaging (MultiKrum). Defaults to 0, in that case classical Krum is applied.
        evaluate_fn : Optional[Callable[[int, NDArrays, Dict[str, Scalar]], Optional[Tuple[float, Dict[str, Scalar]]]]]
            Optional function used for validation. Defaults to None.
        on_fit_config_fn : Callable[[int], Dict[str, Scalar]], optional
            Function used to configure training. Defaults to None.
        on_evaluate_config_fn : Callable[[int], Dict[str, Scalar]], optional
            Function used to configure validation. Defaults to None.
        accept_failures : bool, optional
            Whether or not accept rounds containing failures. Defaults to True.
        initial_parameters : Parameters, optional
            Initial global model parameters. Defaults to None.
        fit_metrics_aggregation_fn : MetricsAggregationFn, optional
            Aggregates the custom metrics clients return from `fit`; called
            with a list of (num_examples, metrics) pairs. Defaults to None.
        evaluate_metrics_aggregation_fn : MetricsAggregationFn, optional
            Aggregates the custom metrics clients return from `evaluate`.
            Defaults to None.
        """

        if (
            min_fit_clients > min_available_clients
            or min_evaluate_clients > min_available_clients
        ):
            log(WARNING, WARNING_MIN_AVAILABLE_CLIENTS_TOO_LOW)

        super().__init__(
            fraction_fit=fraction_fit,
            fraction_evaluate=fraction_evaluate,
            min_fit_clients=min_fit_clients,
            min_evaluate_clients=min_evaluate_clients,
            min_available_clients=min_available_clients,
            evaluate_fn=evaluate_fn,
            on_fit_config_fn=on_fit_config_fn,
            on_evaluate_config_fn=on_evaluate_config_fn,
            accept_failures=accept_failures,
            initial_parameters=initial_parameters,
            fit_metrics_aggregation_fn=fit_metrics_aggregation_fn,
            evaluate_metrics_aggregation_fn=evaluate_metrics_aggregation_fn,
        )
        self.fit_metrics_aggregation_fn = fit_metrics_aggregation_fn
        self.evaluate_metrics_aggregation_fn = evaluate_metrics_aggregation_fn
        self.num_malicious_clients = num_malicious_clients
        self.num_clients_to_keep = num_clients_to_keep

    def __repr__(self) -> str:
        rep = f"Krum(accept_failures={self.accept_failures})"
        return rep

    def aggregate_fit(
        self,
        server_round: int,
        results: List[Tuple[ClientProxy, FitRes]],
        failures: List[Union[Tuple[ClientProxy, FitRes], BaseException]],
    ) -> Tuple[Optional[Parameters], Dict[str, Scalar]]:
        """Aggregate fit results using Krum.

        Returns ``(None, {})`` when there are no results, or when there are
        failures and `accept_failures` is False. Otherwise each client's
        parameters are converted to NumPy arrays and passed, with their
        num_examples, to `aggregate_krum` together with
        `num_malicious_clients` and `num_clients_to_keep`. The result is
        returned as Parameters alongside any custom fit metrics.
        """
        if not results:
            return None, {}
        # Do not aggregate if there are failures and failures are not accepted
        if not self.accept_failures and failures:
            return None, {}

        # Convert results
        weights_results = [
            (parameters_to_ndarrays(fit_res.parameters), fit_res.num_examples)
            for _, fit_res in results
        ]
        parameters_aggregated = ndarrays_to_parameters(
            aggregate_krum(
                weights_results, self.num_malicious_clients, self.num_clients_to_keep
            )
        )

        # Aggregate custom metrics if aggregation fn was provided
        metrics_aggregated = {}
        if self.fit_metrics_aggregation_fn:
            fit_metrics = [(res.num_examples, res.metrics) for _, res in results]
            metrics_aggregated = self.fit_metrics_aggregation_fn(fit_metrics)
        elif server_round == 1:  # Only log this warning once
            log(WARNING, "No fit_metrics_aggregation_fn provided")

        return parameters_aggregated, metrics_aggregated




The only thing left to do is to create the Krum strategy, set the per-client resources, and start the simulation:

In [None]:
import os

# Krum with its defaults: every available client is sampled for training and
# evaluation (fraction_fit = fraction_evaluate = 1.0), and at least 2 clients
# must be available.
strategy = Krum()

# Per-client Ray resources, sized for a 32-core GPU machine (1/5 GPU and 8 CPUs
# per client). Requesting a GPU on a CPU-only machine, or more CPUs than exist,
# can leave the simulated clients unschedulable, so both requests are capped to
# what this machine actually has.
client_resources = {
    "num_gpus": 1/5 if device.type == "cuda" else 0,
    "num_cpus": max(1, min(32, os.cpu_count() or 1) // NUM_CLIENTS),
}

# Start simulation
fl.simulation.start_simulation(
    client_fn=client_fn,
    num_clients=NUM_CLIENTS,
    config=fl.server.ServerConfig(num_rounds=1),
    strategy=strategy,
    client_resources=client_resources,
)
