# **Introduction to FedDalf: Federated Domain Adaptation with Lifelong Learning**


Welcome to this Colab tutorial on federated learning using the FedDalf method!

In this notebook, we will build a federated learning system using FedDalf and PyTorch. In Part 1, we will set up the model training pipeline and data loading with PyTorch. In Part 2, we will introduce FedDalf, a cutting-edge approach that integrates federated learning with domain adaptation and lifelong learning to enhance model performance across different domains.

Explore FedDalf on GitHub ‚≠êÔ∏è  to ask questions and get help.

Let's get started! üöÄ

## Step 0: Preparation

Before we begin with any actual code, let's make sure that we have everything we need.

## Installing dependencies

Setup and Preprocessing for Federated Learning with `FedDalf` in `PyTorch` and `TensorFlow`

In [None]:
# --- Installation des d√©pendances et Imports / Dependency Installation and Imports ---
import sys

# Uninstall potentially conflicting versions
!pip uninstall -y cryptography numpy

# Install specific versions to avoid conflicts
# Pin cryptography to a version compatible with flwr (e.g., 44.0.3 works with flwr 1.23.0)
!pip install cryptography==44.0.3

# Install numpy version compatible with imgaug 0.4.0
!pip install numpy==1.26.4

# Install other dependencies
# Do not install imgaug with its dependencies yet to prevent numpy upgrade
!pip install -q flwr[simulation] tensorflow matplotlib smote_variants tfds-nightly scipy

# Install imgaug without its dependencies to preserve numpy==1.26.4
!pip install imgaug==0.4.0 --no-deps

# Reinstall numpy==1.26.4 with --force-reinstall to ensure it's the final version
!pip install --force-reinstall numpy==1.26.4

# Ensure flwr is up to date (this should not re-install numpy if 1.26.4 is already there)
!pip install -U 'flwr[simulation]'

print("Dependencies installed.")

# --- V√©rification et Imports / Check and Imports ---
try:
    # Attempt import to check if restart is needed
    import flwr as fl
    import imgaug.augmenters as iaa
    from cryptography.hazmat.bindings._rust import PKCS7UnpaddingContext
except ImportError:
    print("\n" + "!"*80)
    print("\u26a0\ufe0f CRITICAL: RUNTIME RESTART REQUIRED / RED√âMARRAGE REQUIS \u26a0\ufe0f")
    print("The 'cryptography' or 'numpy' library was updated, but old versions are loaded in memory.")
    print("1. Go to: Runtime > Restart session (or Restart Runtime).")
    print("2. Run this cell again.")
    print("\nLa biblioth√®que 'cryptography' ou 'numpy' a √©t√© mise √† jour, mais l'ancienne version est charg√©e.")
    print("1. Allez dans : Ex√©cution > Red√©marrer la session.")
    print("2. Relancez cette cellule.")
    print("!"*80 + "\n")
    sys.exit("Please restart the runtime / Veuillez red√©marrer la session.")

NUM_CLASSES = 11
CLASS_LIST=['car_horn','dog_bark','gun_shot','siren','frog','thunder','cat','rooster','water','cock','baby']
IMG_SIZE = 28
input_dim = (16, 8, 1)

# Imports
import os
import cv2
import json
import PIL
import pandas as pd
import random
import numpy as np
import matplotlib.pyplot as plt
import seaborn as sns
import scipy.io
import tensorflow as tf
import keras
import re
import math

from keras.models import Sequential
from keras.layers import Conv2D, Flatten, Dense, MaxPool2D, Dropout
from tensorflow.keras.utils import to_categorical
from tensorflow.keras.models import Sequential, Model
from tensorflow.keras.applications import resnet50, VGG16, InceptionV3, Xception
from tensorflow.keras import datasets, layers, models
from flwr.common import Metrics
from sklearn.impute import SimpleImputer
from sklearn.model_selection import train_test_split
from sklearn.pipeline import make_pipeline
from sklearn.preprocessing import StandardScaler, LabelBinarizer, OneHotEncoder, RobustScaler, LabelEncoder
from sklearn.neighbors import KNeighborsClassifier
from sklearn.compose import ColumnTransformer, make_column_transformer
from sklearn.linear_model import SGDClassifier
from sklearn.cluster import MiniBatchKMeans
from sklearn import metrics
from sklearn.metrics import silhouette_score, roc_curve, auc, roc_auc_score, confusion_matrix
from sklearn.datasets import load_wine
from imblearn.over_sampling import SMOTE
from collections import OrderedDict
from typing import List, Tuple
from sklearn.manifold import TSNE
from sklearn.random_projection import SparseRandomProjection
from skimage import data
from skimage.transform import rotate
from sklearn.utils import shuffle
from PIL import Image
from keras.layers import Dense

print("Libraries imported successfully.")

Found existing installation: cryptography 44.0.3
Uninstalling cryptography-44.0.3:
  Successfully uninstalled cryptography-44.0.3
Found existing installation: numpy 2.2.6
Uninstalling numpy-2.2.6:
  Successfully uninstalled numpy-2.2.6
Collecting cryptography==44.0.3
  Using cached cryptography-44.0.3-cp39-abi3-manylinux_2_34_x86_64.whl.metadata (5.7 kB)
Using cached cryptography-44.0.3-cp39-abi3-manylinux_2_34_x86_64.whl (4.2 MB)
Installing collected packages: cryptography
[31mERROR: pip's dependency resolver does not currently take into account all the packages that are installed. This behaviour is the source of the following dependency conflicts.
flwr 1.23.0 requires numpy<3.0.0,>=1.26.0, which is not installed.
pyopenssl 24.2.1 requires cryptography<44,>=41.0.5, but you have cryptography 44.0.3 which is incompatible.
pydrive2 1.21.3 requires cryptography<44, but you have cryptography 44.0.3 which is incompatible.[0m[31m
[0mSuccessfully installed cryptography-44.0.3
Collecting n



KeyboardInterrupt: 

# Mount Google Drive for Data Access in Colab.

In [None]:
# --- Montage du Drive / Drive Mounting ---
from google.colab import drive

# Monte le Google Drive dans le r√©pertoire /content/drive
# Mounts Google Drive to the /content/drive directory
drive.mount('/content/drive')

Mounted at /content/drive


# 1. Data Management

Data Loading Functions

In [None]:
def charger_donnees(dossier):
    """
    Charge les donn√©es (features et labels) √† partir d'un dossier sp√©cifi√©.
    Les fichiers doivent √™tre au format .npy.

    Loads data (features and labels) from a specified folder.
    Files must be in .npy format.
    """
    # Parcourir le dossier / Browse the folder
    for fichier in os.listdir(dossier):
        if fichier.endswith('.npy'):  # S'assurer que le fichier est un fichier numpy / Ensure the file is a numpy file
            chemin_fichier = os.path.join(dossier, fichier)
            # Charger les caract√©ristiques et les √©tiquettes √† partir du fichier / Load features and labels from the file
            if 'features' in fichier:
                features = np.load(chemin_fichier)
            elif 'labels' in fichier:
                labels = np.load(chemin_fichier)
    return features, labels

# Fonction pour cr√©er des donn√©es clients / Function to create client data
def make_client_data():
    """
    Charge et pr√©pare les donn√©es pour chaque client simul√©.
    Effectue le split train/test et le reshape des images.

    Loads and prepares data for each simulated client.
    Performs train/test split and image reshaping.
    """
    client_folders = [
                      "/content/drive/MyDrive/numpyDataset",
                      "/content/drive/MyDrive/urbansound8k",
                      # "/content/drive/MyDrive/SoundFedLearning/urbansound8k",
                      # "/content/drive/MyDrive/SoundFedLearning/urbansound8k",
                      ]
    client_data = []
    for folder in client_folders:
        X, Y = charger_donnees(folder)
        Y = to_categorical(Y)
        # Split des donn√©es en train et test / Split data into train and test
        X_train, X_test, Y_train, Y_test = train_test_split(X, Y, random_state = 1)
        # Redimensionnement pour correspondre √† l'entr√©e du mod√®le (16x8) / Resizing to match model input (16x8)
        X_train = X_train.reshape(len(X_train), 16, 8, 1)
        X_test = X_test.reshape(len(X_test), 16, 8, 1)
        client_data.append(([X_train], [Y_train], [X_test], [Y_test]))
    return client_data

Data Manipulation and Processing Functions:

In [None]:
def resize(x_train,IMG_SIZE):
        """
        Redimensionne une liste d'images vers une taille donn√©e.
        Resizes a list of images to a given size.
        """
        x_train_resized=[]
        for i in range(len(x_train)):
            x_train_resized.append(cv2.resize(x_train[i],(IMG_SIZE, IMG_SIZE)))
        return x_train_resized

Label Management Functions:

In [None]:
def number_of_labels(y_train):
  """
  Compte le nombre d'√©chantillons √©tiquet√©s (non-NaN).
  Counts the number of labeled samples (non-NaN).
  """
  size=len(y_train)
  nb=0
  if isinstance(y_train[0],np.ndarray) or isinstance(y_train[0],list) :
    for i in range(size):
      if np.isnan(y_train[i][0]):
        continue
      nb=nb+1
  else:
    for i in range(size):
      if np.isnan(y_train[i]):
        continue
      nb=nb+1

  return nb

# 2. History Saving and Loading

In [None]:
# Fonction pour sauvegarder l'historique fit et val round par round {loss, val-loss, acc, val-acc}
# Function to save round by round on-device fit and val history {loss, val-loss, acc, val-acc}
def saving_history_dict(history_dict,path):
  try:
    file_dict = open(path, 'a')
    file_dict.write(str(history_dict))
    file_dict.write('\n')
    file_dict.close()
    print("History data saved")
  except:
    print("Unable to write to file")

def load_list_from_file(path,round): #for fit and val only
  """
  Charge une liste √† partir d'un fichier pour un round sp√©cifique.
  Loads a list from a file for a specific round.
  """
  with open(path) as f:
    history={}
    for line_data in f:
        line_dict={}
        line_data=re.sub('[\']', '"',line_data) #replace ' by " in the string
        line_dict=json.loads(line_data)
        entree=list(line_dict.values())
        if entree[0]==round:
           return entree[1]
    return []

def update_list(filename,round,actual_list):
  """
  Met √† jour la liste des statuts, labels et total_size. Le round doit √™tre le pr√©c√©dent.
  Update the list of status, labels and total_size. round must be the previous round.
  """
  if round==0:
    return actual_list

  last_list=load_list_from_file(filename,round)
  #print("Ancienne liste===>",last_list)
  #print("List_actuelle===>",actual_list)
  for i in range(len(last_list)):
    if last_list[i]!=actual_list[i] and actual_list[i]!=-1:
      last_list[i]=actual_list[i]
  print("Retourne==>",last_list)
  return last_list

# 3. Utility Functions (for Clients and Categories)

Client Management Functions:

In [None]:
def get_normalized_list(nbTotalClient,clients_name,client_x):
  """
  Retourne une liste normalis√©e de taille nbTotalClient.
  Returns a normalized list of size nbTotalClient.
  """
  normalized_x=[-1 for _ in range(nbTotalClient)]
  for client,elt in zip(clients_name,client_x):
    normalized_x[client]=elt
  return normalized_x

def get_selected_categorie_set(nbTotalClient,clients_name,clients_status,categorie_list):
  """
  S√©lectionne les clients en fonction de leur statut et des cat√©gories requises.
  0: totalement √©tiquet√©, 1: partiellement √©tiquet√©, 2: autres.

  Selects clients based on their status and required categories.
  0: fully labeled, 1: partially labeled, 2: others.
  """
  normalized_status=get_normalized_list(nbTotalClient,clients_name,clients_status) # Normalize the status list
  selected_clients=[0 for _ in range(nbTotalClient)]
  for i in range(len(normalized_status)):
    if normalized_status[i] in categorie_list:
      selected_clients[i]=1
  return selected_clients

Dictionary Creation Functions:

In [None]:
# Cr√©e un dictionnaire √† partir d'une liste de cl√©s et d'une liste de valeurs
# Creates a dictionary from a list of keys and a list of values
def create_dictionary(Names_list,Values_list):
  dictionary={}
  for key,value in zip(Names_list,Values_list):
    dictionary[key]=value
  return dictionary

# Creating the Audio CNN Model for Classification

In [None]:
# Fonction pour cr√©er un mod√®le Keras
# Function to create a Keras model
def create_keras_model():
    """
    D√©finit l'architecture du mod√®le CNN.
    Defines the CNN model architecture.
    """
    model = Sequential()
    # First convolutional layer with 64 filters, 3x3 kernel, 'same' padding, and ReLU activation
    model.add(Conv2D(64, (3, 3), padding="same", activation="relu", input_shape=input_dim))
    # Max pooling layer to reduce the spatial dimensions by half
    model.add(MaxPool2D(pool_size=(2, 2)))
    # Second convolutional layer with 128 filters, 3x3 kernel, 'same' padding, and ReLU activation
    model.add(Conv2D(128, (3, 3), padding="same", activation="relu"))
    # Another max pooling layer
    model.add(MaxPool2D(pool_size=(2, 2)))
    # Dropout layer with a dropout rate of 0.1 to prevent overfitting
    model.add(Dropout(0.1))
    # Flatten the 3D output from convolutional layers to a 1D vector before the dense layers
    model.add(Flatten())
    # Fully connected dense layer with 1024 units and ReLU activation
    model.add(Dense(1024, activation="relu"))
    # Output layer with 9 units and softmax activation for multi-class classification
    model.add(Dense(9, activation="softmax"))
    return model

Aggregating Client Data for Training and Testing


In [None]:
# Agr√©gation des donn√©es clients pour l'entra√Ænement et le test
# Aggregating Client Data for Training and Testing

# --- D√©finitions de secours / Fallback definitions ---
if 'charger_donnees' not in globals():
    def charger_donnees(dossier):
        features, labels = [], []
        if not os.path.exists(dossier): return np.array([]), np.array([])
        for fichier in os.listdir(dossier):
            if fichier.endswith('.npy'):
                path = os.path.join(dossier, fichier)
                if 'features' in fichier: features = np.load(path)
                elif 'labels' in fichier: labels = np.load(path)
        return features, labels

if 'make_client_data' not in globals():
    def make_client_data():
        client_folders = ["/content/drive/MyDrive/numpyDataset", "/content/drive/MyDrive/urbansound8k"]
        client_data = []
        for folder in client_folders:
            if os.path.exists(folder):
                X, Y = charger_donnees(folder)
                if len(X) > 0:
                    Y = to_categorical(Y)
                    X_train, X_test, Y_train, Y_test = train_test_split(X, Y, random_state=1)
                    X_train = X_train.reshape(len(X_train), 16, 8, 1)
                    X_test = X_test.reshape(len(X_test), 16, 8, 1)
                    client_data.append(([X_train], [Y_train], [X_test], [Y_test]))
        return client_data

# V√©rification et chargement / Check and load
if 'data_all' not in globals():
    print("Chargement des donn√©es... / Loading data...")
    data_all = make_client_data()

all_X_train, all_y_train, all_X_test, all_y_test = [], [], [], []
if 'data_all' in globals() and data_all:
    for data in data_all:
        x_train, y_train, x_test, y_test = data
        all_X_train.append(np.array(x_train[0]))
        all_y_train.append(np.array(y_train[0]))
        all_X_test.append(np.array(x_test[0]))
        all_y_test.append(np.array(y_test[0]))
else:
    print("‚ö†Ô∏è Aucune donn√©e charg√©e. V√©rifiez les chemins d'acc√®s sur Google Drive. / No data loaded. Check Google Drive paths.")

Chargement des donn√©es... / Loading data...


Filtering and Aggregating Clients Based on Selected Classes

In [None]:
# --- Pr√©requis : Chargement des donn√©es / Prerequisites: Data Loading ---
if 'all_y_train' not in globals():
    print("Chargement des donn√©es manquant, ex√©cution de make_client_data()... / Missing data loading, running make_client_data()...")
    # On suppose que make_client_data est d√©fini ou on le d√©finit ici si n√©cessaire
    if 'make_client_data' not in globals():
         # Fallback simple si la cellule pr√©c√©dente n'est pas ex√©cut√©e
         print("Fonction make_client_data non trouv√©e. Veuillez ex√©cuter les cellules de d√©finition des donn√©es.")
    else:
         data_all = make_client_data()
         all_X_train, all_y_train, all_X_test, all_y_test = [], [], [], []
         for data in data_all:
            x_train, y_train, x_test, y_test = data
            all_X_train.append(np.array(x_train[0]))
            all_y_train.append(np.array(y_train[0]))
            all_X_test.append(np.array(x_test[0]))
            all_y_test.append(np.array(y_test[0]))

def renew_list():
  """
  R√©initialise une liste de compteurs pour les 17 classes potentielles.
  Resets a list of counters for the 17 potential classes.
  """
  L=[]
  for i in range(17):
    L.append(0)
  return L

# Indices des classes sp√©cifiques que nous voulons conserver/analyser
# Indices of specific classes we want to keep/analyze
indexed_slices=[1,3,6,8]
clients_to_consider=[]

# 1. Identification des clients valides / Identification of valid clients
# On parcourt tous les clients pour voir s'ils poss√®dent des donn√©es pour les classes cibles (indexed_slices)
if 'all_y_train' in globals():
    for client,elt in enumerate(all_y_train):
      class_flag=True
      L=renew_list()
      for e in elt:
        indix=np.argmax(e)
        L[indix]=L[indix]+1

      # V√©rifie si le client a au moins un exemple pour chaque classe cible
      for i in indexed_slices:
        if L[i]==0:
          class_flag=False
          break
      if class_flag==False:
        print(client,'==>out (Client rejet√© / Client rejected)')
      else:
        clients_to_consider.append(client)
        print(client,'==>in (Client s√©lectionn√© / Client selected)')

    print("Clients retenus / Retained clients:", clients_to_consider)

    # 2. Filtrage des donn√©es / Data Filtering
    new_all_X_train, new_all_y_train, new_all_X_test, new_all_y_test = [], [], [], []
    client=0
    for x_train,y_train in zip(all_X_train,all_y_train):
      x_t,y_t=[],[]
      if client in clients_to_consider:
        for x_,y_ in zip(x_train,y_train):
          if np.argmax(y_) in indexed_slices:
            x_t.append(x_)
            y_t.append(y_)
      client=client+1
      if len(x_t)!=0:
        new_all_X_train.append(np.array(x_t))
        new_all_y_train.append(np.array(y_t))

    # 3. V√©rification de la distribution / Distribution check
    for client,elt in enumerate(new_all_y_train):
      L=renew_list()
      for e in elt:
        indix=np.argmax(e)
        L[indix]=L[indix]+1
      print(f"Distribution Client {client}: {L}, Total: {sum(L)}")

    for i in range(len(new_all_X_train)):
      print(f"Client {i} Shape: X={new_all_X_train[i].shape}, Y={new_all_y_train[i].shape}")
else:
    print("Erreur: Donn√©es non charg√©es (all_y_train).")

0 ==>in (Client s√©lectionn√© / Client selected)
1 ==>in (Client s√©lectionn√© / Client selected)
Clients retenus / Retained clients: [0, 1]
Distribution Client 0: [0, 673, 0, 687, 0, 0, 677, 0, 691, 0, 0, 0, 0, 0, 0, 0, 0], Total: 2728
Distribution Client 1: [0, 350, 0, 752, 0, 0, 274, 0, 696, 0, 0, 0, 0, 0, 0, 0, 0], Total: 2072
Client 0 Shape: X=(2728, 16, 8, 1), Y=(2728, 9)
Client 1 Shape: X=(2072, 16, 8, 1), Y=(2072, 10)


MODEL GLOBAL

In [None]:
from keras.utils import split_dataset
from imgaug import augmenters as iaa

def disturb_labels(y_train,n):
  """
  Simule des √©tiquettes manquantes en rempla√ßant les n derniers labels par NaN.
  Simulates missing labels by replacing the last n labels with NaN.
  """
  size=len(y_train)
  if not(isinstance(y_train[0],np.ndarray)) and not(isinstance(y_train[0],list)) :
    y_train=y_train.astype('float')
  y_train_copy=y_train.copy()
  indice_no_labels=size-n

  if  n==0:
    return y_train
  if indice_no_labels<0 :
    indice_no_labels=0
  for i in range(indice_no_labels,size,1):
    y_train_copy[i]=np.nan
  return y_train_copy

def generate_one_hotpot_vector(position,size):
  """
  G√©n√®re un vecteur one-hot.
  Generates a one-hot vector.
  """
  vector=[0.0 for i in range(size)]
  vector[position]=1.0
  return vector

def max_and_position(L):
  """
  Retourne le max de la liste et sa position.
  Returns the max of the list and its position.
  """
  max_e=max(L)
  return max_e,L.tolist().index(max_e)

def map_predict(Y_pred,threshold):
  """
  Transforme les probabilit√©s en vecteurs one-hot si la confiance d√©passe le seuil.
  Transforms probabilities into one-hot vectors if confidence exceeds threshold.
  """
  count=0
  for i in range(len(Y_pred)):
    size=len(Y_pred[i])
    acc,position=max_and_position(Y_pred[i])
    if acc>=threshold:
      Y_pred[i]=generate_one_hotpot_vector(position,size)
      count=count+1
    else:
      Y_pred[i]=np.nan
  return Y_pred,count

def update_y_train(Y_train,Y_pred):
  """
  Corrige Y_pred avec les nouvelles pseudo-√©tiquettes provenant de Y_train.
  Corrects Y_pred with new pseudo-labels from Y_train.
  """
  assert len(Y_train)==len(Y_pred),"Oh no! Y_train doesn't have the same size as Y_pred!"
  if isinstance(Y_train[0],np.ndarray) or isinstance(Y_train[0],list):
    for i in range(len(Y_train)):
      if np.isnan(Y_train[i][0])==False:
        Y_pred[i]=Y_train[i]
  else:
    for i in range(len(Y_train)):
      if np.isnan(Y_train[i])==False:
        Y_pred[i]=Y_train[i]
  return Y_pred


def get_labeled_set(x_train,y_train):
  """
  Retourne uniquement les √©chantillons √©tiquet√©s.
  Returns only labeled samples.
  """
  size=len(y_train)
  y=[]
  x=[]
  if isinstance(y_train[0],np.ndarray) or isinstance(y_train[0],list):
    for i in range(size):
      if np.isnan(y_train[i][0])==False:
        y.append(y_train[i])
        x.append(x_train[i])
  else:
    for i in range(size):
      if np.isnan(y_train[i])==False:
        y.append(y_train[i])
        x.append(x_train[i])
  return np.array(x),np.array(y)

def get_unlabeled_set(x_train,y_train):
  """
  Retourne uniquement les √©chantillons non √©tiquet√©s.
  Returns only unlabeled samples.
  """
  size=len(y_train)
  y=[]
  x=[]
  if isinstance(y_train[0],np.ndarray) or isinstance(y_train[0],list):
    for i in range(size):
      if np.isnan(y_train[i][0]):
        y.append(y_train[i])
        x.append(x_train[i])
  else:
    for i in range(size):
      if np.isnan(y_train[i]):
        y.append(y_train[i])
        x.append(x_train[i])
  return np.array(x),np.array(y)

AttributeError: `np.sctypes` was removed in the NumPy 2.0 release. Access dtypes explicitly instead.

In [None]:
# --- Test de s√©paration des donn√©es (Client 0) / Data Split Test (Client 0) ---
# V√©rification du split train/test sur le premier client.
# Checking train/test split on the first client.

from sklearn.model_selection import train_test_split

if 'new_all_X_train' in globals() and len(new_all_X_train) > 0:
    X_train, X_test, y_train, y_test = train_test_split(new_all_X_train[0], new_all_y_train[0], test_size=0.33, random_state=42)
    print("Split successful. X_train shape:", X_train.shape)
else:
    print("‚ö†Ô∏è Les donn√©es 'new_all_X_train' ne sont pas d√©finies. Veuillez ex√©cuter la cellule de filtrage des donn√©es pr√©c√©dente.")
    print("‚ö†Ô∏è 'new_all_X_train' data is not defined. Please run the previous data filtering cell.")

In [None]:
# --- Configuration des chemins et param√®tres / Path and Parameter Setup ---
import os

#Change le chemin
initial_path_all_users='/content/drive/MyDrive/FEDADL/history/'
if not os.path.exists(initial_path_all_users):
    os.makedirs(initial_path_all_users)

# Check if new_all_X_train exists before using len()
if 'new_all_X_train' in globals():
    NUM_CLIENTS = len(new_all_X_train)
else:
    NUM_CLIENTS = 10 # Default fallback
    print("NUM_CLIENTS set to default (10) because data is not loaded.")

# Define input_dim if not present (defaulting to the value used elsewhere)
if 'input_dim' not in globals():
    input_dim = (16, 8, 1)

IMG_SHAPE=input_dim
base_learning_rate = 0.0001
def ecrire_dans_fichier(data,fichier_resultat):
    with open(fichier_resultat, 'a') as fichier:
        fichier.write(data + '\n')

In [None]:
# --- Configuration / Configuration ---
if 'initial_path_all_users' not in globals():
    initial_path_all_users = '/content/drive/MyDrive/FEDADL/history/'
    if not os.path.exists(initial_path_all_users):
        os.makedirs(initial_path_all_users)

if 'base_learning_rate' not in globals():
    base_learning_rate = 0.0001

NUM_CLIENTS = 10
FRACTION_CLIENTS=1.0 # Fraction des clients s√©lectionn√©s √† chaque round (1.0 = tous)
MINIMUM_CLIENTS=NUM_CLIENTS # Nombre minimum de clients requis
EPOCHS=3
NUM_CLASS=17
NUM_ROUNDS=50
initial_path= initial_path_all_users+'evaluation/' # Chemin de sauvegarde / Save path
input_dim = (16, 8, 1)

# --- D√©finition du mod√®le / Model Definition ---
def create_keras_model():
    """
    Cr√©e et compile le mod√®le CNN Keras.
    Creates and compiles the Keras CNN model.
    """
    model = Sequential()
    model.add(Conv2D(64, (3, 3), padding = "same", activation = "tanh", input_shape = input_dim))
    model.add(MaxPool2D(pool_size=(2, 2)))
    model.add(Conv2D(128, (3, 3), padding = "same", activation = "tanh"))
    model.add(MaxPool2D(pool_size=(2, 2)))
    model.add(Dropout(0.1))
    model.add(Flatten())
    model.add(Dense(1024, activation = "tanh"))
    model.add(Dense(17, activation = "softmax"))
    model.compile(optimizer=tf.keras.optimizers.Adam(learning_rate=base_learning_rate),loss='categorical_crossentropy',metrics=['accuracy'])
    return model

model=create_keras_model()

# --- D√©finition du Client Flower / Flower Client Definition ---
class FlowerClient(fl.client.NumPyClient):
  def __init__(self,model,train_data_X,train_data_Y,val_X,val_Y,cid):
     self.model=model
     self.train_data_X=train_data_X
     self.train_data_Y=train_data_Y
     self.val_X=val_X
     self.val_Y=val_Y
     self.cid=cid

  def get_parameters(self, config):
    return model.get_weights()

  def fit(self, parameters, config):
    client_name='client_'+str(self.cid)
    self.model.set_weights(parameters)
    history = self.model.fit(self.train_data_X, self.train_data_Y, epochs=EPOCHS, validation_data=(self.val_X, self.val_Y),verbose=0)

    current_round=config['current_round']
    dict={'round'+str(current_round):history.history}
    path=initial_path+client_name+'.txt'
    saving_history_dict(dict, path)

    path_local_eval=initial_path+'Local_'+client_name+'.txt'
    loss, acc = self.model.evaluate(self.val_X, self.val_Y,verbose=0)
    hist={'Local_loss':[loss], 'Local_accuracy':[acc]}
    saving_history_dict({'round'+str(current_round):hist}, path_local_eval)

    return self.model.get_weights(), len(self.train_data_X), {'cid':self.cid}

  def evaluate(self, parameters, config):
    client_name='client_'+str(self.cid)
    current_round=config['current_round']
    self.model.set_weights(parameters)
    Global_loss, Global_accuracy = self.model.evaluate(self.val_X, self.val_Y,verbose=0)

    hist={'Global_loss':[Global_loss], 'Global_accuracy':[Global_accuracy]}
    path=initial_path+'Eval_'+client_name+'.txt'
    saving_history_dict({'round'+str(current_round):hist}, path)

    try:
      if current_round==NUM_ROUNDS:
        model.save(initial_path+'model')
    except: pass
    return Global_loss, len(self.val_X), {"cid":self.cid,"accuracy": Global_accuracy,"loss":Global_loss,"round":config['current_round']}

def client_fn(cid: str) -> FlowerClient:
    if 'new_all_X_train' not in globals():
        print("Erreur: Donn√©es non filtr√©es (new_all_X_train manque).")
        # Retourner un client factice ou g√©rer l'erreur
        return None
    x_train, x_test, y_train, y_test = train_test_split(new_all_X_train[int(cid)], new_all_y_train[int(cid)], test_size=0.2, random_state=42)
    return FlowerClient(model, x_train,y_train,x_test,y_test,cid).to_client()

# --- C√¥t√© Serveur / Server Side ---
def weighted_average(metrics: List[Tuple[int, Metrics]]) -> Metrics:
    accuracies = [num_examples * m["accuracy"] for num_examples, m in metrics]
    losses=[num_examples * m["loss"] for num_examples, m in metrics]
    examples = [num_examples for num_examples, _ in metrics]
    Global_accuracy=sum(accuracies) / sum(examples)
    Global_loss=sum(losses) / sum(examples)

    current_round=metrics[0][1]["round"] if metrics else 0
    print("Total clients for evaluation:",len(accuracies))

    hist={'eval_loss':[Global_loss], 'eval_accuracy':[Global_accuracy]}
    saving_history_dict({'round'+str(current_round):hist}, initial_path+'Evaluation.txt')
    return {"accuracy": Global_accuracy}

def fit_config(server_round: int):
    return {"batch_size": 32, "current_round": server_round, "local_epochs": 3}

def eval_config(server_round: int):
    return {"current_round": server_round}

class SaveModelStrategy(fl.server.strategy.FedAvg):
    def configure_fit(self, server_round, parameters, client_manager):
        print("Clients disponibles / Available clients:",client_manager.all().keys())
        client_fit_ins_list=super().configure_fit(server_round, parameters, client_manager)

        selected_client=[client.cid for (client,_) in client_fit_ins_list]
        clients_status=[1 if str(i) in selected_client else 0 for i in range(len(client_manager.all().keys()))]

        saving_history_dict(create_dictionary(['round','status'],[server_round, clients_status]), initial_path+'selected.txt')
        return client_fit_ins_list

    def aggregate_fit(self,server_round,results,failures):
        for _,parameters in results: print('Client:',parameters.metrics['cid'])
        return super().aggregate_fit(server_round, results, failures)

strategy = SaveModelStrategy(
    fraction_fit=FRACTION_CLIENTS, fraction_evaluate=FRACTION_CLIENTS,
    min_fit_clients=MINIMUM_CLIENTS, min_evaluate_clients=NUM_CLIENTS,
    min_available_clients=MINIMUM_CLIENTS,
    on_fit_config_fn=fit_config, on_evaluate_config_fn=eval_config,
    evaluate_metrics_aggregation_fn=weighted_average,
    initial_parameters=fl.common.ndarrays_to_parameters(model.get_weights()),
)

fl.simulation.start_simulation(
    client_fn=client_fn, num_clients=NUM_CLIENTS,
    config=fl.server.ServerConfig(num_rounds=NUM_ROUNDS), strategy=strategy
)

Now that we have all dependencies installed, we can import everything we need for this tutorial: