# WSN-IDS: Machine Learning Based Intrusion Detection for Wireless Sensor Networks

This notebook walks through the full pipeline, one numbered section per step:
1. Dataset generation
2. Exploratory data analysis
3. Feature engineering & train/test split
4. Model training (Random Forest, Decision Tree, KNN, SVM, MLP)
5. Evaluation (Accuracy, Precision, Recall, F1, FPR, Energy Impact)
6. Result visualisations
7. Real-time inference demo
8. Summary table

In [None]:
# Put the parent directory at the front of sys.path
# (presumably so the local `wsn_ids` package resolves when the notebook is run
# from its own folder — confirm against the repository layout).
import sys, os
sys.path.insert(0, os.path.abspath('..'))

import numpy as np
import pandas as pd
import matplotlib
# Select the non-interactive Agg backend before pyplot is imported, so figures
# are rendered to files instead of being shown in a window.
matplotlib.use('Agg')
import matplotlib.pyplot as plt
import matplotlib.image as mpimg
import seaborn as sns

from wsn_ids import ATTACK_LABELS, FEATURE_NAMES
print('Imports OK')

## 1. Generate Dataset

In [None]:
from wsn_ids.data.generate_dataset import generate_dataset
from wsn_ids.features.feature_extraction import add_derived_features

# Build the raw dataset with 500 samples per class (save_path is handed to the
# generator, presumably as the CSV destination), then add the derived columns.
df_raw = generate_dataset(
    samples_per_class=500,
    save_path='../results/wsn_dataset.csv',
)
df = add_derived_features(df_raw)

# Quick sanity report: overall shape and number of rows per attack type.
print(f'Dataset shape : {df.shape}')
print('\nClass counts:')
print(df['attack_type'].value_counts())

In [None]:
# Preview the first rows of the feature-augmented frame; as the cell's final
# expression, the result is what the notebook renders.
df.head()

In [None]:
# Summary statistics of the frame's columns, rounded to three decimals for
# readability; rendered as the cell output.
df.describe().round(3)

## 2. Exploratory Data Analysis

In [None]:
from pathlib import Path
from wsn_ids.visualization.plots import (
    plot_class_distribution,
    plot_feature_distributions,
    plot_correlation_heatmap,
    plot_pca_scatter,  # imported here but not called in this cell
)

# Make sure the plot output folder exists before the helpers are asked to
# save into it.
Path('../results/plots').mkdir(parents=True, exist_ok=True)

# Base feature names followed by the three derived ones; keep only those that
# are really present as columns of df.
all_features = [
    *FEATURE_NAMES,
    'energy_efficiency_index',
    'traffic_anomaly_score',
    'identity_confusion_index',
]
available = [name for name in all_features if name in df.columns]

plot_class_distribution(df, save_dir='../results/plots')
plot_feature_distributions(
    df,
    features=available,
    save_dir='../results/plots',
)
plot_correlation_heatmap(
    df,
    features=available,
    save_dir='../results/plots',
)

print('EDA plots saved.')

In [None]:
# Re-render the saved class-distribution PNG into a borderless figure and write
# a copy to /tmp. With the Agg backend nothing is shown inline; the copy on
# disk is the only output of this cell.
matplotlib.use('Agg')
img = mpimg.imread('../results/plots/class_distribution.png')

fig, ax = plt.subplots(figsize=(12, 5))
ax.imshow(img)
ax.axis('off')  # hide ticks and frame so only the image remains
fig.tight_layout()
fig.savefig('/tmp/display_class.png', dpi=100)
plt.close(fig)  # release the figure once it is on disk

print('Plot saved to /tmp/display_class.png')

## 3. Feature Engineering & Split

In [None]:
from wsn_ids.features.feature_extraction import split_dataset

# Hold out 20% of the rows for testing; the helper also returns the list of
# feature columns it used.
X_train, X_test, y_train, y_test, feature_cols = split_dataset(
    df,
    test_size=0.20,
)

# One line per fact, emitted with a single print call.
print(
    f'Training samples : {len(X_train)}',
    f'Test samples     : {len(X_test)}',
    f'Features used    : {feature_cols}',
    sep='\n',
)

In [None]:
# Mean of each derived feature per attack type, rounded to three decimals.
# The trailing expression is rendered as the cell output.
derived = [
    'energy_efficiency_index',
    'traffic_anomaly_score',
    'identity_confusion_index',
]
df.groupby('attack_type')[derived].mean().round(3)

## 4. Train ML Models

In [None]:
from wsn_ids.models.train import build_models, train_all, save_models

# Build the model collection with a fixed seed so runs are repeatable.
models = build_models(random_state=42)
print('Training models ...')
# train_all's return value is not captured, so it presumably fits the entries
# of `models` in place — confirm against wsn_ids.models.train.
train_all(models, X_train, y_train)
# Hand the models to save_models together with the target directory.
save_models(models, save_dir='../results/models')
print('Done.')

## 5. Evaluation

In [None]:
from wsn_ids.models.evaluate import evaluate_all, per_class_report, cross_validate_model

# Score every model on the held-out split. `summary` is indexed by column name
# below; `details` is looked up by model name in a later cell.
summary, details = evaluate_all(models, X_test, y_test)

print('\nModel Summary:')
# Show only the headline metric columns; rendered as the cell output.
summary[
    [
        'model',
        'accuracy',
        'precision',
        'recall',
        'f1_score',
        'false_positive_rate',
        'energy_impact',
    ]
]

In [None]:
# Treat the first row of the summary as the best model
# (presumably evaluate_all returns it sorted best-first — confirm).
best_name = summary['model'].iloc[0]
best_pred = details[best_name]['y_pred']

print(f'Best model: {best_name}')
print('\nPer-class classification report:')
# Per-class breakdown of the best model's predictions on the test split.
per_class_report(y_test, best_pred)

In [None]:
# Cross-validate the best model on the recombined train + test samples.
# NOTE(review): no fold count is passed to cross_validate_model, so the
# "5-fold" wording relies on that helper's default — confirm.
X_all = np.vstack([X_train, X_test])
y_all = np.concatenate([y_train, y_test])

print(f'Running 5-fold CV on {best_name} ...')
cv_df = cross_validate_model(models[best_name], X_all, y_all)
# Trailing expression: the cross-validation result is rendered as cell output.
cv_df

## 6. Result Visualisations

In [None]:
from wsn_ids.visualization.plots import (
    plot_confusion_matrix,
    plot_cv_scores,
    plot_energy_impact,
    plot_feature_importance,
    plot_model_comparison,
    plot_pca_scatter,
    plot_roc_curves,
)

# Cross-model charts driven by the evaluation summary.
plot_model_comparison(summary, save_dir='../results/plots')
plot_energy_impact(summary, save_dir='../results/plots')

# Charts specific to the model selected as best.
plot_cv_scores(cv_df, best_name, save_dir='../results/plots')
plot_confusion_matrix(
    y_test,
    best_pred,
    model_name=best_name,
    save_dir='../results/plots',
)
plot_roc_curves(
    models[best_name],
    X_test,
    y_test,
    model_name=best_name,
    save_dir='../results/plots',
)

# PCA scatter over the combined train + test samples.
plot_pca_scatter(X_all, y_all, save_dir='../results/plots')

# Feature importances are read from the Random Forest entry regardless of
# which model ranked best.
rf = models['Random Forest']
plot_feature_importance(
    rf.feature_importances_,
    feature_cols,
    model_name='Random Forest',
    save_dir='../results/plots',
)

print('All result plots saved.')

## 7. Real-Time Inference Demo

In [None]:
from wsn_ids.ids import WSNIDS

# Build a WSNIDS instance backed by already-trained models: the attributes
# below are assigned directly from objects created in earlier cells instead of
# re-running the pipeline.
ids = WSNIDS(results_dir='../results')
ids.df = df
ids.models = models
ids.feature_cols = feature_cols
ids.best_model_name = best_name

# Each scenario is (display name, observation vector of seven raw readings).
# NOTE(review): the vectors hold 7 values — presumably the base features in
# FEATURE_NAMES order, with derived features computed inside alert(); confirm.
# NOTE(review): in every other row the first two values sum to 0.99–1.00, but
# 'Node Compromise' sums to 1.10 — confirm this is intended.
scenarios = [
    ('Healthy Sensor',     [0.95, 0.04,  65.0,  5, 0.04,  2.5, 310]),
    ('Suspected Sinkhole', [0.52, 0.48,  70.0, 15, 0.07,  4.0, 295]),
    ('Suspected Sybil',    [0.82, 0.18,  58.0, 32, 0.11,  9.5, 305]),
    ('Suspected DoS',      [0.28, 0.72,  18.0,  6, 0.58,  9.2,  20]),
    ('Hello Flood',        [0.76, 0.24,  38.0, 20, 0.48, 13.0,  30]),
    ('Selective Forward',  [0.35, 0.65,  55.0,  5, 0.05,  3.0, 300]),
    ('Node Compromise',    [0.60, 0.50,  48.0,  9, 0.38,  8.0, 250]),
]

# Run each observation through the detector and print whatever alert() returns.
for name, obs in scenarios:
    print(f'\nNode: {name}')
    print(f'  {ids.alert(obs)}')

## 8. Summary Table

In [None]:
# Final comparison table: the best value in each highlighted column is shaded
# green — highest for accuracy / F1 / energy impact, lowest for false-positive
# rate / inference time — and numbers are shown with 4-digit precision.
# NOTE(review): energy_impact is highlighted at its maximum; confirm that a
# higher value is the desirable direction for this metric.
cols = [
    'model',
    'accuracy',
    'precision',
    'recall',
    'f1_score',
    'false_positive_rate',
    'energy_impact',
    'inference_ms',
]
(
    summary[cols]
    .style
    .highlight_max(
        subset=['accuracy', 'f1_score', 'energy_impact'],
        color='lightgreen',
    )
    .highlight_min(
        subset=['false_positive_rate', 'inference_ms'],
        color='lightgreen',
    )
    .format(precision=4)
)