# Federated Logistic Regression for IoT Intrusion Detection (CYBRIA)

This notebook runs a federated learning experiment on the CYBRIA IoT
network intrusion dataset. It covers:

- Centralized logistic regression baseline
- Federated learning with multiple clients
- FedAvg aggregation
- Accuracy comparison


In [None]:
import matplotlib.pyplot as plt
from sklearn.linear_model import LogisticRegression
from sklearn.metrics import classification_report, confusion_matrix

from federated.data_loader import (
    load_cybria_base,
    select_feature_columns,
    split_into_clients,
    make_centralized_train_test,
)
from federated.server import FederatedServer


In [None]:
# Load the CYBRIA table through the project loader and report how many rows came back.
df = load_cybria_base("data/cybria.csv")
n_rows = len(df)
print(f"Loaded {n_rows} rows from CYBRIA dataset.")

# Ask the helper which columns to use as model inputs (max_features=20 is passed
# through; presumably it caps the number of columns -- confirm in federated.data_loader).
feature_cols = select_feature_columns(df, max_features=20)
print("Using feature columns:", feature_cols, sep="\n")


In [None]:
# Centralized baseline: fit one logistic regression on the train split and
# evaluate it on the held-out split.
# test_size=0.2 is forwarded to the project helper; presumably it reserves 20%
# of the rows for testing -- confirm in federated.data_loader.
X_train, X_test, y_train, y_test = make_centralized_train_test(
    df,
    feature_cols,
    label_col="Label",  # change if your label is named differently
    test_size=0.2,
)

# n_jobs is deliberately left at its default. With the lbfgs solver the fit is
# a single optimisation problem, so n_jobs=-1 had nothing to parallelise and
# only added worker-dispatch overhead; newer scikit-learn releases also
# deprecate the parameter (verify against the installed version).
central_model = LogisticRegression(
    max_iter=200,
    solver="lbfgs",
)

central_model.fit(X_train, y_train)
y_pred = central_model.predict(X_test)

# Per-class precision/recall/F1 (3 decimal places) followed by the raw confusion matrix.
print("=== Centralized Logistic Regression Performance ===")
print(classification_report(y_test, y_pred, digits=3))
print("Confusion matrix:\n", confusion_matrix(y_test, y_pred))


In [None]:
# Build the federated setup: ask the project helper for 3 client datasets drawn
# from the same frame and feature columns, then hand them to the server.
clients = split_into_clients(df, feature_cols, label_col="Label", n_clients=3)
server = FederatedServer(clients=clients)

n_clients = len(clients)
print(f"Created {n_clients} federated clients.")


In [None]:
# Run 5 federated rounds; afterwards the server exposes one entry per round in
# `round_accuracies`, each entry being the collection of per-client accuracies.
server.run_training(num_rounds=5)

# Report every round (numbered from 1) with its per-client values and their mean.
round_idx = 0
for accs in server.round_accuracies:
    round_idx += 1
    avg_acc = sum(accs) / len(accs)
    print(f"Round {round_idx}: per-client={accs}, avg={avg_acc:.3f}")


In [None]:
# Mean client accuracy for each federated round, in round order.
avg_acc_per_round = []
for accs in server.round_accuracies:
    avg_acc_per_round.append(sum(accs) / len(accs))

# Plot the per-round mean against 1-based round numbers.
rounds = range(1, len(avg_acc_per_round) + 1)
fig, ax = plt.subplots()
ax.plot(rounds, avg_acc_per_round, marker="o")
ax.set_xlabel("Federated Round")
ax.set_ylabel("Average Client Accuracy")
ax.set_title("Federated Logistic Regression on CYBRIA")
ax.grid(True)
plt.show()
