In [1]:
import torch
import matplotlib.pyplot as plt
from model import LogisticRegressionModel
from federated import train_federated
from client import get_dataloaders_per_client, get_local_models
from metrics import compare_global_local, plot_auc_heatmap

device = "cuda" if torch.cuda.is_available() else "cpu"


In [2]:
import pandas as pd
import numpy as np
import torch
from torch.utils.data import DataLoader, TensorDataset
from sklearn.preprocessing import StandardScaler

# 1. Chargement des données CSV
df_fr = pd.read_csv("/home/onyxia/work/Federated_Learning_Milliman/data/french_data.csv").dropna()
df_be = pd.read_csv("/home/onyxia/work/Federated_Learning_Milliman/data/belgium_data.csv").dropna()
df_eu = pd.read_csv("/home/onyxia/work/Federated_Learning_Milliman/data/european_data.csv").dropna()

# 2. Nom de la colonne cible
label_col = "Sinistre"  

# 3. Fonction de validation des colonnes
def check_columns_consistency(dfs, label_col):
    base_cols = dfs[0].drop(columns=[label_col]).columns
    for i, df in enumerate(dfs[1:], start=1):
        other_cols = df.drop(columns=[label_col]).columns
        if not all(base_cols == other_cols):
            raise ValueError(f"Les colonnes des DataFrames ne sont pas cohérentes entre df[0] et df[{i}].")

# 4. Fonction de prétraitement et DataLoader
def create_dataloaders_from_dfs(dfs, label_col, batch_size=32, scale=True):
    check_columns_consistency(dfs, label_col)
    loaders = []
    scaler = StandardScaler() if scale else None

    # Ajuster le scaler globalement pour homogénéiser les échelles
    if scale:
        all_X = pd.concat([df.drop(columns=[label_col]) for df in dfs])
        scaler.fit(all_X)

    for df in dfs:
        X = df.drop(columns=[label_col]).values.astype(np.float32)
        y = df[label_col].values.astype(np.float32)
        if scale:
            X = scaler.transform(X)
        X_tensor = torch.tensor(X)
        y_tensor = torch.tensor(y)
        dataset = TensorDataset(X_tensor, y_tensor)
        loader = DataLoader(dataset, batch_size=batch_size, shuffle=True)
        loaders.append(loader)

    return loaders

# 5. Construction des DataLoaders
dfs = [df_fr, df_be, df_eu]
train_loaders = create_dataloaders_from_dfs(dfs, label_col=label_col, batch_size=32)

# 6. Récupération de la dimension d'entrée
input_dim = dfs[0].drop(columns=[label_col]).shape[1]

print(f"DataLoaders créés pour {len(train_loaders)} clients avec input_dim = {input_dim}")





DataLoaders créés pour 3 clients avec input_dim = 6


In [10]:
A = [df_be, df_eu, df_fr]
for i in range(3):
    print(A[i].shape)

(163212, 7)
(2372377, 7)
(1091182, 7)


In [6]:
global_model, metrics = train_federated(
    train_loaders, input_dim=input_dim, algo=algo, T=20, C=1.0,
    E=3, B=32, lr=0.05, mu=0.0,eta=0.01, beta1=0.9, beta2=0.99,
    tau=1e-6, device=device
)

KeyboardInterrupt: 

In [10]:
# Choisir l'algorithme : 'fedavg', 'fedprox', 'fedopt_adam', 'fedopt_yogi', 'fedopt_adagrad'
algo = 'fedavg'

global_model, metrics = train_federated(
    train_loaders=train_loaders,
    input_dim=input_dim,
    algo='fedavg',
    T=2,            # 2 rounds
    C=0.4,    # 50 % des clients
    E=1,            # 1 époque
    B=16,           # plus petit batch
    lr=0.05,
    mu=0.0,
    eta=0.01,
    beta1=0.9,
    beta2=0.99,
    tau=1e-6,
    device=device
)




KeyboardInterrupt: 

In [None]:
import pandas as pd

coefs = pd.DataFrame(metrics["coefficients"], columns=[f"var_{i}" for i in range(input_dim)])
coefs.plot(figsize=(10,6), title="Évolution des coefficients")
plt.xlabel("Round fédéré")
plt.ylabel("Valeur du coefficient")
plt.grid(True)
plt.legend(loc="center left", bbox_to_anchor=(1, 0.5))
plt.tight_layout()
plt.show()


In [None]:
# Récupération des modèles locaux finaux
local_models = get_local_models(train_loaders, global_model, E=3, B=32, lr=0.05, mu=0.1,
                                global_weights=global_model.get_state_dict(), algo=algo, device=device)

results = compare_global_local(global_model, local_models, train_loaders, device=device)

import pandas as pd
results_df = pd.DataFrame(results)
print(results_df)


In [None]:
plot_auc_heatmap(local_models, train_loaders, title=f"Heatmap AUC croisée - {algo.upper()}", device=device)