# FedDalf Audio FL (Federated Domain Adaptation & Lifelong Learning for Audio)
Welcome to this Colab tutorial on federated learning using the FedDalf method!

In this notebook, we will build a federated learning system using FedDalf, Flower, and TensorFlow/Keras. In Part 1, we will set up the model training pipeline and data loading with TensorFlow/Keras. In Part 2, we will introduce FedDalf, a cutting-edge approach that integrates federated learning with domain adaptation and lifelong learning to enhance model performance across different domains.

Explore FedDalf on GitHub ⭐️ to ask questions and get help.

Let's get started! 🚀


## Préparation (optionnel Colab) / Preparation (optional Colab)
- Monter Drive (`drive.mount('/content/drive')`). / Mount Drive (`drive.mount('/content/drive')`).
- Lancer `install_dependencies()` si l'environnement n'a pas les versions épinglées. / Run `install_dependencies()` if the environment lacks the pinned versions.


### Installation / Installation
Cette cellule installe les versions épinglées (TensorFlow/Flower/numpy/imgaug) pour éviter les conflits en environnement Colab.
This cell installs pinned versions (TensorFlow/Flower/numpy/imgaug) to avoid conflicts in Colab environments.


In [None]:
# --- Installation optionnelle pour Colab / Optional installation for Colab ---
import os

def install_dependencies():
    """Install the package set used by this notebook and verify key imports.

    Each pip command is printed and then run through the shell, in order.
    A command that exits with a non-zero status is reported but does not
    stop the sequence, so later (re)install steps still get a chance to run.

    Raises:
        SystemExit: if ``flwr``, ``imgaug.augmenters`` or the expected
            ``cryptography`` binding cannot be imported afterwards. The
            message asks the user to restart the runtime and re-run the cell.
    """
    commands = [
        "pip uninstall -y cryptography numpy",
        "pip install cryptography==44.0.3",
        "pip install numpy==1.26.4",
        "pip install -q flwr[simulation] tensorflow matplotlib smote_variants tfds-nightly scipy",
        "pip install imgaug==0.4.0 --no-deps",
        "pip install --force-reinstall numpy==1.26.4",
        "pip install -U 'flwr[simulation]'",
    ]
    for cmd in commands:
        print(f"Running: {cmd}")
        status = os.system(cmd)
        if status != 0:
            # Best-effort: surface the failure instead of ignoring it silently.
            print(f"Command exited with status {status}: {cmd}")
    try:
        # Import checks only; the bound names are intentionally discarded.
        import flwr as _  # noqa: F401
        import imgaug.augmenters as _  # noqa: F401
        from cryptography.hazmat.bindings._rust import PKCS7UnpaddingContext  # noqa: F401
    except ImportError as err:
        raise SystemExit("Redémarrer le runtime puis relancer cette cellule.") from err


### Imports et configuration / Imports and configuration
Définit les dépendances principales, les constantes globales (dimensions, classes, hyperparamètres) et les chemins d'export.
Defines main dependencies, global constants (dimensions, classes, hyperparameters), and export paths.


In [None]:
# --- Imports & configuration / Imports & configuration ---
import re
import json
import numpy as np
import tensorflow as tf
from tensorflow.keras import layers, models
from tensorflow.keras.utils import to_categorical
from sklearn.model_selection import train_test_split
from flwr.common import Metrics
import flwr as fl

try:
    import cv2
except ImportError:
    cv2 = None

# Number of classes: width of the one-hot labels and of the final softmax layer.
NUM_CLASSES = 17
# Per-sample shape the features are reshaped to and the CNN takes as input.
INPUT_DIM = (16, 8, 1)
# Class indices a client must contain to be kept; other classes are filtered out.
INDEXED_SLICES = [1, 3, 6, 8]
# Learning rate of the Adam optimizer used to compile the model.
BASE_LR = 1e-4
# Batch size sent to clients in the fit configuration.
BATCH_SIZE = 32
# Number of local training epochs per federated round.
EPOCHS = 3
# Number of federated rounds run by the simulation.
NUM_ROUNDS = 50
# Fraction of clients sampled for both fit and evaluate.
FRACTION_CLIENTS = 1.0
# Expected minimum number of participating clients (used when starting the simulation).
MINIMUM_CLIENTS = 10
# Root output folder (a Google Drive path; requires Drive to be mounted).
INITIAL_PATH_ALL_USERS = "/content/drive/MyDrive/FEDADL/history/"
# Folder receiving the history/evaluation text files written during training.
INITIAL_PATH = os.path.join(INITIAL_PATH_ALL_USERS, "evaluation/")


### Utilitaires d'E/S / IO utilities
Fonctions pour créer les dossiers, écrire/relire les historiques et mettre à jour les listes persistantes.
Functions to create folders, write/read histories, and update persisted lists.


In [None]:
# --- IO utils / Outils d'E/S ---
import os
from pathlib import Path

def ensure_dir(path: str) -> None:
    """Create ``path`` (and any missing parents); do nothing if it exists."""
    os.makedirs(name=path, exist_ok=True)
    return None


def saving_history_dict(history_dict: dict, path: str) -> None:
    """Append ``history_dict`` to ``path`` as one text line.

    The record is stored as ``str(history_dict)`` followed by a newline so
    that each call produces exactly one line, which is the format the
    line-by-line reader in this file expects.

    Args:
        history_dict: record to persist.
        path: destination text file; opened in append mode.

    A failure to open or write the file is reported on stdout and otherwise
    ignored, so logging problems never interrupt training.
    """
    try:
        with open(path, "a", encoding="utf-8") as f:
            # One record per line; without the terminator successive records
            # would be glued together and become unparseable.
            f.write(f"{history_dict}\n")
        print(f"History saved -> {path}")
    except OSError:
        print(f"Unable to write history to {path}")


def load_list_from_file(path: str, round_id: int) -> list:
    """Return the list stored for ``round_id`` in a line-oriented history file.

    Every non-blank line is expected to be a dict literal whose first value
    is the round identifier and whose second value is the stored list. Single
    quotes are converted to double quotes so the line can be parsed as JSON.

    Args:
        path: history file to read.
        round_id: round whose list is requested.

    Returns:
        The second value of the first matching record, or ``[]`` when the
        file does not exist, no record matches, or the matching record has
        no second value.

    Raises:
        json.JSONDecodeError: if a non-blank line is not parseable.
    """
    if not os.path.exists(path):
        return []
    with open(path, encoding="utf-8") as f:
        for line in f:
            if not line.strip():
                # Blank separator lines carry no record.
                continue
            record = json.loads(line.replace("'", '"'))
            values = list(record.values())
            # Need both the round id and its payload to answer the query.
            if len(values) >= 2 and values[0] == round_id:
                return values[1]
    return []


def update_list(filename: str, round_id: int, current: list) -> list:
    """Fill the ``-1`` placeholders of ``current`` from the persisted list.

    For round 0, or when nothing is stored for ``round_id`` in ``filename``,
    ``current`` is returned unchanged. Otherwise a new list is built in which
    every ``-1`` entry of ``current`` is replaced by the stored value at the
    same position (the result is as long as the shorter of the two lists).
    """
    previous = load_list_from_file(filename, round_id) if round_id != 0 else []
    if not previous:
        return current
    merged = []
    for now, before in zip(current, previous):
        merged.append(before if now == -1 else now)
    return merged


### Chargement et filtrage des données / Data loading and filtering
Charge les paires features/labels .npy par client, effectue le split train/test et filtre les clients/contenus sur les classes ciblées.
Loads per-client .npy feature/label pairs, performs train/test split, and filters clients/records on target classes.


In [None]:
# --- Data loading & filtering / Chargement et filtrage ---
def _load_numpy_pair(folder: str):
    features, labels = None, None
    for file in os.listdir(folder):
        if not file.endswith(".npy"):
            continue
        path = os.path.join(folder, file)
        if "features" in file:
            features = np.load(path)
        elif "labels" in file:
            labels = np.load(path)
    return features, labels


def make_client_data(client_folders):
    """Build one train/test dataset per client folder.

    Folders that do not exist, or that lack a features or labels array, are
    skipped. Labels are one-hot encoded over ``NUM_CLASSES``, the data is
    split with ``train_test_split(random_state=1)``, and the feature arrays
    are reshaped to ``(n_samples, *INPUT_DIM)``.

    Returns:
        A list of ``(X_train, y_train, X_test, y_test)`` tuples.
    """
    datasets = []
    for folder in client_folders:
        if os.path.exists(folder):
            features, labels = _load_numpy_pair(folder)
        else:
            features, labels = None, None
        if features is None or labels is None:
            continue
        one_hot = to_categorical(labels, num_classes=NUM_CLASSES)
        x_tr, x_te, y_tr, y_te = train_test_split(features, one_hot, random_state=1)
        x_tr = x_tr.reshape((len(x_tr),) + tuple(INPUT_DIM))
        x_te = x_te.reshape((len(x_te),) + tuple(INPUT_DIM))
        datasets.append((x_tr, y_tr, x_te, y_te))
    return datasets


def resize_images(x_train, img_size):
    """Resize every image in ``x_train`` to ``img_size`` x ``img_size``.

    Raises:
        ImportError: if OpenCV (``cv2``) could not be imported.

    Returns:
        The resized images stacked into a single array.
    """
    if cv2 is not None:
        target = (img_size, img_size)
        resized = []
        for image in x_train:
            resized.append(cv2.resize(image, target))
        return np.stack(resized)
    raise ImportError("cv2 not available; install opencv-python to use resize_images")


def number_of_labels(y_train):
    """Count the samples that carry a label (i.e. are not NaN).

    For arrays with more than one dimension only the first column is
    inspected; for 1-D input every entry is inspected.
    """
    labels = np.asarray(y_train)
    probe = labels[:, 0] if labels.ndim > 1 else labels
    present = ~np.isnan(probe)
    return int(np.sum(present))


def renew_list(size: int = NUM_CLASSES):
    """Return a fresh list of ``size`` zeros (one counter per class by default)."""
    zeros = [0 for _ in range(size)]
    return zeros


def filter_clients_by_classes(all_x, all_y, indexed_slices):
    """Keep only clients that contain every target class, restricted to those classes.

    Args:
        all_x: per-client feature arrays (indexable with a boolean mask).
        all_y: per-client one-hot label arrays, aligned with ``all_x``.
        indexed_slices: class indices that every retained client must contain.

    Returns:
        ``(selected_x, selected_y)``: for each retained client, the samples
        whose class (argmax of the label row) is in ``indexed_slices``.
        Clients with no samples, or missing at least one target class, are
        dropped.
    """
    wanted = np.asarray(list(indexed_slices), dtype=int)
    selected_x, selected_y = [], []
    for x_train, y_train in zip(all_x, all_y):
        if len(y_train) == 0:
            # A client without samples cannot contain the target classes.
            continue
        # Compute each sample's class once and reuse it for both checks.
        class_ids = np.argmax(y_train, axis=1)
        if not np.isin(wanted, class_ids).all():
            continue
        mask = np.isin(class_ids, wanted)
        selected_x.append(x_train[mask])
        selected_y.append(y_train[mask])
    return selected_x, selected_y


### Gestion des labels / Label handling
Outils pour simuler des labels manquants, générer/mettre à jour des pseudo-labels et séparer jeux étiquetés vs non étiquetés.
Tools to simulate missing labels, generate/update pseudo-labels, and split labeled vs unlabeled sets.


In [None]:
# --- Label utilities / Gestion des labels ---
def disturb_labels(y_train, n):
    """Return a float copy of ``y_train`` whose last ``n`` entries are NaN.

    The input is never modified. ``n <= 0`` leaves the copy untouched and an
    ``n`` larger than the number of entries blanks all of them.
    """
    labels = np.array(y_train, dtype=float)
    if n > 0:
        first_blank = max(len(labels) - n, 0)
        labels[first_blank:] = np.nan
    return labels


def generate_one_hotpot_vector(position, size):
    """Return a list of ``size`` floats that is 1.0 at ``position`` and 0.0 elsewhere.

    ``position`` follows normal list indexing, so an out-of-range value
    raises ``IndexError``.
    """
    one_hot = [0.0 for _ in range(size)]
    one_hot[position] = 1.0
    return one_hot


def map_predict(y_pred, threshold):
    """Turn prediction rows into pseudo-labels where the model is confident.

    For each row the highest score is compared with ``threshold``: rows that
    reach it become a one-hot vector on the winning class, the others become
    NaN.

    Returns:
        ``(labels, count)`` where ``labels`` is an object-dtype array of the
        per-row results and ``count`` is the number of rows that reached the
        threshold.
    """
    pseudo_labels = []
    accepted = 0
    for scores in y_pred:
        best = int(np.argmax(scores))
        if float(scores[best]) < threshold:
            # Not confident enough: leave the sample unlabeled.
            pseudo_labels.append(np.nan)
            continue
        pseudo_labels.append(generate_one_hotpot_vector(best, len(scores)))
        accepted += 1
    return np.array(pseudo_labels, dtype=object), accepted


def update_y_train(y_train, y_pred):
    """Overwrite pseudo-labels with the true label wherever one exists.

    Args:
        y_train: true labels; an entry (scalar or row) containing NaN marks
            an unlabeled sample.
        y_pred: pseudo-labels aligned with ``y_train``.

    Returns:
        ``y_pred`` as an object-dtype array in which every position that has
        a true label holds that label. If ``y_pred`` already is an
        object-dtype array it is updated in place.
    """
    y_train = np.asarray(y_train)
    y_pred = np.asarray(y_pred, dtype=object)
    for idx, label in enumerate(y_train):
        try:
            # Works uniformly for scalars (including NumPy float scalars),
            # lists and array rows.
            missing = bool(np.isnan(np.asarray(label, dtype=float)).any())
        except (TypeError, ValueError):
            # Non-numeric labels cannot be NaN, so they count as present.
            missing = False
        if not missing:
            y_pred[idx] = label
    return y_pred


def get_labeled_set(x_train, y_train):
    """Return the samples of ``x_train`` and their labels that contain no NaN.

    For labels with more than one dimension a row counts as unlabeled as soon
    as any of its entries is NaN.
    """
    labels = np.asarray(y_train)
    missing = np.isnan(labels)
    if labels.ndim > 1:
        missing = missing.any(axis=1)
    keep = ~missing
    return x_train[keep], labels[keep]


def get_unlabeled_set(x_train, y_train):
    """Return the samples of ``x_train`` and their labels that contain NaN.

    For labels with more than one dimension a row is selected as soon as any
    of its entries is NaN.
    """
    labels = np.asarray(y_train)
    if labels.ndim > 1:
        unlabeled = np.isnan(labels).any(axis=1)
    else:
        unlabeled = np.isnan(labels)
    return x_train[unlabeled], labels[unlabeled]


### Aides clients/catégories / Client/category helpers
Normalisation des listes par identifiant client et sélection des statuts/catégories pour le suivi fédéré.
Normalize lists by client id and select statuses/categories for federated tracking.


In [None]:
# --- Client/category helpers / Aides clients-catégories ---
def get_normalized_list(nb_total, clients_name, client_x):
    """Spread per-client values into a list indexed by client id.

    Returns a list of length ``nb_total`` pre-filled with ``-1``; for every
    ``(name, value)`` pair, ``value`` is stored at index ``int(name)``.
    """
    slots = [-1 for _ in range(nb_total)]
    for name, value in zip(clients_name, client_x):
        position = int(name)
        slots[position] = value
    return slots


def get_selected_categorie_set(nb_total, clients_name, clients_status, categorie_list):
    """Flag, per client id, whether the client's status is in ``categorie_list``.

    Statuses are first laid out by client id (missing clients get ``-1``);
    the result holds 1 where the status belongs to ``categorie_list`` and 0
    elsewhere.
    """
    statuses = get_normalized_list(nb_total, clients_name, clients_status)
    flags = []
    for status in statuses:
        flags.append(int(status in categorie_list))
    return flags


def create_dictionary(names_list, values_list):
    """Pair names with values positionally and return the resulting dict.

    Extra items in the longer input are ignored; a repeated name keeps its
    last value.
    """
    return {name: value for name, value in zip(names_list, values_list)}


### Modèle CNN / CNN model
Définit et compile le CNN Keras (tanh, 17 classes) utilisé par tous les clients fédérés.
Defines and compiles the Keras CNN (tanh, 17 classes) used by all federated clients.


In [None]:
# --- Model definition / Définition du modèle ---
def create_keras_model():
    """Build and compile the CNN shared by all federated clients.

    Two tanh convolution + max-pooling stages, a light dropout, a 1024-unit
    tanh dense layer and a softmax over ``NUM_CLASSES``. Compiled with Adam
    at ``BASE_LR``, categorical cross-entropy loss and accuracy as metric.
    """
    stack = [
        layers.Conv2D(64, (3, 3), padding="same", activation="tanh", input_shape=INPUT_DIM),
        layers.MaxPool2D(pool_size=(2, 2)),
        layers.Conv2D(128, (3, 3), padding="same", activation="tanh"),
        layers.MaxPool2D(pool_size=(2, 2)),
        layers.Dropout(0.1),
        layers.Flatten(),
        layers.Dense(1024, activation="tanh"),
        layers.Dense(NUM_CLASSES, activation="softmax"),
    ]
    network = models.Sequential(stack)
    optimizer = tf.keras.optimizers.Adam(learning_rate=BASE_LR)
    network.compile(
        optimizer=optimizer,
        loss="categorical_crossentropy",
        metrics=["accuracy"],
    )
    return network


### Client Flower et stratégie / Flower client and strategy
Implémente le client Flower (fit/evaluate), la stratégie FedAvg personnalisée et la configuration d'agrégation/logs.
Implements the Flower client (fit/evaluate), customized FedAvg strategy, and aggregation/logging setup.


In [None]:
# --- Flower client & strategy / Client Flower et stratégie ---
class FlowerClient(fl.client.NumPyClient):
    """Flower NumPy client that trains and evaluates a Keras model on local data.

    Training and evaluation metrics are appended to per-client text files
    under ``INITIAL_PATH`` on every round.
    """

    def __init__(self, model, train_x, train_y, val_x, val_y, cid):
        """Store the model, the local train/validation arrays and the client id."""
        self.model = model
        self.train_x = train_x
        self.train_y = train_y
        self.val_x = val_x
        self.val_y = val_y
        self.cid = cid

    def get_parameters(self, config):
        """Return the current model weights."""
        return self.model.get_weights()

    def fit(self, parameters, config):
        """Train locally starting from the given global weights.

        Honors ``local_epochs`` and ``batch_size`` from ``config`` when the
        server provides them; otherwise falls back to ``EPOCHS`` and the
        Keras default batch size.

        Returns:
            ``(weights, number_of_training_samples, {"cid": ...})``.
        """
        self.model.set_weights(parameters)
        history = self.model.fit(
            self.train_x,
            self.train_y,
            epochs=config.get("local_epochs", EPOCHS),
            batch_size=config.get("batch_size"),  # None keeps the Keras default
            validation_data=(self.val_x, self.val_y),
            verbose=0,
        )
        current_round = config.get("current_round", 0)
        client_name = f"client_{self.cid}"
        saving_history_dict({f"round{current_round}": history.history}, os.path.join(INITIAL_PATH, f"{client_name}.txt"))
        # Score the locally trained model before it is sent back for aggregation.
        loss, acc = self.model.evaluate(self.val_x, self.val_y, verbose=0)
        saving_history_dict({f"round{current_round}": {"Local_loss": [loss], "Local_accuracy": [acc]}},
                            os.path.join(INITIAL_PATH, f"Local_{client_name}.txt"))
        return self.model.get_weights(), len(self.train_x), {"cid": self.cid}

    def evaluate(self, parameters, config):
        """Evaluate the given global weights on the local validation data.

        On the final round (``current_round == NUM_ROUNDS``) the model is also
        saved under ``INITIAL_PATH``; a failure to save is reported but does
        not fail the evaluation.

        Returns:
            ``(loss, number_of_validation_samples, metrics)`` where metrics
            holds the client id, accuracy, loss and round number.
        """
        self.model.set_weights(parameters)
        current_round = config.get("current_round", 0)
        loss, acc = self.model.evaluate(self.val_x, self.val_y, verbose=0)
        client_name = f"client_{self.cid}"
        saving_history_dict({f"round{current_round}": {"Global_loss": [loss], "Global_accuracy": [acc]}},
                            os.path.join(INITIAL_PATH, f"Eval_{client_name}.txt"))
        if current_round == NUM_ROUNDS:
            try:
                self.model.save(os.path.join(INITIAL_PATH, "model"))
            except Exception as err:
                # Best-effort export: keep the round alive but make the failure visible.
                print(f"Unable to save model for client {self.cid}: {err}")
        return float(loss), len(self.val_x), {"cid": self.cid, "accuracy": float(acc), "loss": float(loss), "round": current_round}


def weighted_average(metrics):
    """Aggregate client evaluation metrics, weighted by example count.

    Args:
        metrics: list of ``(num_examples, metrics_dict)`` pairs; each dict
            must provide ``"accuracy"`` and ``"loss"`` and may provide
            ``"round"``.

    Returns:
        ``{"accuracy": weighted_accuracy}``, or ``{}`` when there is nothing
        to aggregate (no entries, or zero examples in total). In the latter
        case nothing is written to disk.

    Side effect: appends the weighted loss and accuracy for the round to
    ``Evaluation.txt`` under ``INITIAL_PATH``.
    """
    total_examples = sum(num_examples for num_examples, _ in metrics)
    if total_examples == 0:
        # Avoid dividing by zero when no client reported any example.
        return {}
    global_accuracy = sum(num_examples * m["accuracy"] for num_examples, m in metrics) / total_examples
    global_loss = sum(num_examples * m["loss"] for num_examples, m in metrics) / total_examples
    # All clients report the same round; read it from the first entry.
    current_round = metrics[0][1].get("round", 0)
    saving_history_dict({f"round{current_round}": {"eval_loss": [global_loss], "eval_accuracy": [global_accuracy]}},
                        os.path.join(INITIAL_PATH, "Evaluation.txt"))
    return {"accuracy": global_accuracy}


def fit_config(server_round: int):
    """Return the training configuration sent to clients for ``server_round``."""
    return dict(
        batch_size=BATCH_SIZE,
        current_round=server_round,
        local_epochs=EPOCHS,
    )


def eval_config(server_round: int):
    """Return the evaluation configuration sent to clients for ``server_round``."""
    return dict(current_round=server_round)


class SaveModelStrategy(fl.server.strategy.FedAvg):
    """FedAvg variant that records which clients were selected each round."""

    def configure_fit(self, server_round, parameters, client_manager):
        """Delegate to FedAvg, then append the selection status to ``selected.txt``.

        The status list has one entry per registered client: 1 when the
        string form of its index equals the cid of a selected client, else 0.
        """
        instructions = super().configure_fit(server_round, parameters, client_manager)
        # NOTE(review): assumes client cids are the strings "0".."n-1" — confirm
        # against the Flower version in use.
        chosen = {proxy.cid for proxy, _ in instructions}
        registered = len(client_manager.all())
        clients_status = [int(str(index) in chosen) for index in range(registered)]
        record = {"round": server_round, "status": clients_status}
        saving_history_dict(record, os.path.join(INITIAL_PATH, "selected.txt"))
        return instructions

    def aggregate_fit(self, server_round, results, failures):
        """Print the cid reported by each client, then delegate to FedAvg."""
        for _, fit_result in results:
            print("Client:", fit_result.metrics.get("cid"))
        return super().aggregate_fit(server_round, results, failures)


def client_fn_builder(all_x, all_y):
    """Return a Flower ``client_fn`` bound to the given per-client datasets.

    The returned function interprets ``cid`` as an index into ``all_x`` /
    ``all_y``, splits that client's data 80/20 (``random_state=42``) and
    wraps a freshly built model in a ``FlowerClient``.
    """
    def client_fn(cid: str):
        client_index = int(cid)
        features = all_x[client_index]
        labels = all_y[client_index]
        x_train, x_test, y_train, y_test = train_test_split(
            features, labels, test_size=0.2, random_state=42
        )
        flower_client = FlowerClient(create_keras_model(), x_train, y_train, x_test, y_test, cid)
        return flower_client.to_client()

    return client_fn


def start_federated_simulation(all_x, all_y):
    """Run the Flower simulation with one client per dataset in ``all_x``.

    Args:
        all_x: per-client feature arrays.
        all_y: per-client label arrays, aligned with ``all_x``.

    Raises:
        ValueError: if ``all_x`` is empty, since no client could be built.

    The strategy's minimum client counts are set to the actual number of
    clients: requiring more clients than the simulation creates would leave
    the server waiting for clients that never connect.
    """
    if not all_x:
        raise ValueError("No client data available; cannot start the federated simulation.")
    ensure_dir(INITIAL_PATH)
    num_clients = len(all_x)
    if num_clients < MINIMUM_CLIENTS:
        print(f"Only {num_clients} client(s) available (expected at least {MINIMUM_CLIENTS}); "
              "running with the available clients.")
    strategy = SaveModelStrategy(
        fraction_fit=FRACTION_CLIENTS,
        fraction_evaluate=FRACTION_CLIENTS,
        min_fit_clients=num_clients,
        min_evaluate_clients=num_clients,
        min_available_clients=num_clients,
        on_fit_config_fn=fit_config,
        on_evaluate_config_fn=eval_config,
        evaluate_metrics_aggregation_fn=weighted_average,
        # Start every run from the same freshly initialised model weights.
        initial_parameters=fl.common.ndarrays_to_parameters(create_keras_model().get_weights()),
    )
    fl.simulation.start_simulation(
        client_fn=client_fn_builder(all_x, all_y),
        num_clients=num_clients,
        config=fl.server.ServerConfig(num_rounds=NUM_ROUNDS),
        strategy=strategy,
    )


### Orchestration / Orchestration
Pipeline principal : initialise les chemins, charge et filtre les données, affiche la distribution puis lance la simulation fédérée.
Main pipeline: initialize paths, load/filter data, print distributions, then launch federated simulation.


In [None]:
# --- Orchestration / Orchestration ---
def main():
    """Load the client datasets, filter them and launch the federated simulation.

    Steps: create the output folder, load each client folder, keep only the
    clients and samples matching ``INDEXED_SLICES``, print the per-client
    class distribution, then start the simulation. Returns early with a
    message when no data could be loaded or when filtering leaves no client.
    """
    ensure_dir(INITIAL_PATH)
    client_folders = [
        "/content/drive/MyDrive/numpyDataset",
        "/content/drive/MyDrive/urbansound8k",
    ]
    data_all = make_client_data(client_folders)
    # Only the training split of each client is used from here on.
    all_X_train = [np.array(x_tr) for x_tr, _, _, _ in data_all]
    all_y_train = [np.array(y_tr) for _, y_tr, _, _ in data_all]
    if not all_X_train:
        print("⚠️ No data loaded. Check Google Drive paths.")
        return
    new_all_x, new_all_y = filter_clients_by_classes(all_X_train, all_y_train, INDEXED_SLICES)
    if not new_all_x:
        # Every client was missing at least one target class.
        print(f"No client contains all target classes {INDEXED_SLICES}; nothing to train.")
        return
    for idx, labels in enumerate(new_all_y):
        counts = renew_list()
        for lbl in labels:
            counts[int(np.argmax(lbl))] += 1
        print(f"Client {idx} distribution: {counts} (total={sum(counts)})")
    start_federated_simulation(new_all_x, new_all_y)

# main()  # Uncomment to launch directly
