In [None]:
# --- Notebook: 08_single_task_attack_mlp.ipynb ---
# Goal:
# - Single-task Intrusion Detection (attack_id) on CIC IoT-IDAD 2024 packet-based CSVs
# - Ignore device_id for now
# - Use robust preprocessing, simple but strong MLP baseline
# - Evaluate if we can approach reported high attack accuracy

# ============================================================
# 0. Environment & paths
# ============================================================

import sys
import os
from pathlib import Path
import json

import numpy as np
import pandas as pd

import torch
import torch.nn as nn
import torch.nn.functional as F
from torch.utils.data import Dataset, DataLoader

from sklearn.metrics import classification_report, confusion_matrix

# Reproducibility: seed NumPy's legacy global RNG and PyTorch's default generator.
RANDOM_SEED = 42
np.random.seed(RANDOM_SEED)
torch.manual_seed(RANDOM_SEED)

# The project root is taken to be the parent of the current working directory
# (presumably the notebook lives one level below the repo root — verify if moved).
PROJECT_ROOT = Path(os.getcwd()).resolve().parents[0]
sys.path.append(str(PROJECT_ROOT))

DATA_DIR = PROJECT_ROOT / "data"
PROCESSED_DIR = DATA_DIR.joinpath("processed")
REPORTS_DIR = PROJECT_ROOT.joinpath("reports")
REPORTS_DIR.mkdir(parents=True, exist_ok=True)

print("PROJECT_ROOT:", PROJECT_ROOT)
print("Using processed data from:", PROCESSED_DIR)

# Accelerator preference: Apple MPS first, then CUDA, falling back to CPU.
device = torch.device(
    "mps" if torch.backends.mps.is_available()
    else "cuda" if torch.cuda.is_available()
    else "cpu"
)

print("Using device:", device)

# ============================================================
# 1. Config
# ============================================================

# ---- Input files (all under PROCESSED_DIR) ----
TRAIN_PATH, VAL_PATH, TEST_PATH = (
    PROCESSED_DIR / f"packets_{split}.csv" for split in ("train", "val", "test")
)
ATTACK_LABEL_MAP_PATH = PROCESSED_DIR / "attack_label_mapping.json"

# ---- Optimisation hyper-parameters ----
BATCH_SIZE: int = 512
NUM_EPOCHS: int = 30          # max epochs; early stopping (PATIENCE) may end sooner
LEARNING_RATE: float = 3e-4   # peak LR reached at the end of warmup
WEIGHT_DECAY: float = 1e-4    # AdamW weight decay
MAX_GRAD_NORM: float = 1.0    # global gradient-norm clipping threshold
PATIENCE: int = 5             # epochs without val-acc improvement before stopping

LABEL_SMOOTHING: float = 0.0  # 0.0 = plain cross-entropy (baseline reproduction)

# ============================================================
# 2. Load data
# ============================================================

print("\nLoading train/val/test CSVs (full)...")
# Read the three splits in order: train, val, test.
train_df, val_df, test_df = (
    pd.read_csv(path) for path in (TRAIN_PATH, VAL_PATH, TEST_PATH)
)

print("Train shape:", train_df.shape)
print("Val   shape:", val_df.shape)
print("Test  shape:", test_df.shape)

# Mapping JSON is expected to hold an "id_to_attack" object keyed by the
# stringified class id (JSON object keys are always strings).
attack_label_mapping = json.loads(ATTACK_LABEL_MAP_PATH.read_text())["id_to_attack"]
num_attacks = len(attack_label_mapping)
print("Number of attack classes:", num_attacks)

# ============================================================
# 3. Single-task target + features
# ============================================================

TARGET_COL = "attack_id"
# Columns that must never be model inputs even when numeric: the target itself
# and, if the processed CSVs contain it, device_id (the notebook header says to
# ignore it in this single-task run; presumably it is the other task's label).
# NOTE(review): confirm no other numeric label/ID columns exist in the CSVs.
NON_FEATURE_COLS = (TARGET_COL, "device_id")

# Sanity: ensure the target is present before deriving anything from the frame.
if TARGET_COL not in train_df.columns:
    raise ValueError(f"{TARGET_COL} not found in train_df.")

numeric_cols = train_df.select_dtypes(include=["int64", "float64"]).columns.tolist()
feature_cols = [c for c in numeric_cols if c not in NON_FEATURE_COLS]
excluded_cols = [c for c in numeric_cols if c in NON_FEATURE_COLS]

print("\nNumber of feature columns:", len(feature_cols))
print("Excluded numeric label columns:", excluded_cols)
print("Example features:", feature_cols[:15])

# val/test must carry the same schema as train; otherwise the later
# df[feature_cols] lookups fail with an uninformative KeyError.
required_cols = feature_cols + [TARGET_COL]
for split_name, split_df in (("val_df", val_df), ("test_df", test_df)):
    missing_cols = [c for c in required_cols if c not in split_df.columns]
    if missing_cols:
        raise ValueError(
            f"{split_name} is missing columns present in train_df: {missing_cols[:10]}"
        )

# ============================================================
# 4. Robust preprocessing (no NaN/Inf)
# ============================================================

def clean_df(df, feature_cols, name):
    """Make ``df[feature_cols]`` finite by mapping +/-inf to NaN and NaN to 0.

    The frame is modified in place and also returned for convenience. Columns
    outside ``feature_cols`` are not touched.

    Args:
        df: DataFrame to clean (mutated).
        feature_cols: columns to sanitise.
        name: split label used only in the log line.

    Returns:
        The same ``df`` object.
    """
    df[feature_cols] = df[feature_cols].replace([np.inf, -np.inf], np.nan)
    # This count includes the former infinities, since they are NaN now.
    missing_count = df[feature_cols].isna().sum().sum()
    if not missing_count:
        return df
    print(f"  [{name}] NaN before fill: {missing_count}, filling with 0.")
    df[feature_cols] = df[feature_cols].fillna(0)
    return df

print("\nCleaning NaN/Inf...")
train_df = clean_df(train_df, feature_cols, "train")
val_df   = clean_df(val_df, feature_cols, "val")
test_df  = clean_df(test_df, feature_cols, "test")

# Standardise with statistics from the training split only (val/test never
# influence the scaler).
print("\nStandardising features...")
CLIP_VALUE = 10.0  # standardised values are clipped to [-CLIP_VALUE, CLIP_VALUE]
means = train_df[feature_cols].mean()
# std() is 0 for constant columns and NaN when train has fewer than 2 rows;
# map both to 1.0 so such columns are only mean-centred instead of becoming Inf/NaN.
stds = train_df[feature_cols].std().replace(0, 1.0).fillna(1.0)

for df, name in [(train_df, "train"), (val_df, "val"), (test_df, "test")]:
    df[feature_cols] = ((df[feature_cols] - means) / stds).clip(-CLIP_VALUE, CLIP_VALUE)
    # Extract once and reuse for both checks instead of building two copies.
    values = df[feature_cols].to_numpy()
    n_nan = int(np.isnan(values).sum())
    n_inf = int(np.isinf(values).sum())
    print(f"  [{name}] NaN after std: {n_nan}, Inf: {n_inf}")
    if n_nan > 0 or n_inf > 0:
        raise ValueError(f"Found NaN/Inf in {name} after standardisation.")

# Save everything needed to reproduce this preprocessing at inference time:
# clean_df zero-fills NaN/+-inf, then (x - mean) / std, then clip.
scaler_path = PROCESSED_DIR / "single_task_attack_scaler.json"
with open(scaler_path, "w") as f:
    json.dump(
        {
            "means": means.to_dict(),
            "stds": stds.to_dict(),
            "fill_value": 0.0,
            "clip": [-CLIP_VALUE, CLIP_VALUE],
        },
        f,
        indent=2,
    )

# ============================================================
# 5. Dataset / DataLoader
# ============================================================

class AttackDataset(Dataset):
    """In-memory map-style dataset of (feature vector, attack label) pairs.

    Features are materialised once as a row-major (C-contiguous) float32 matrix
    and labels as an int64 vector, so ``__getitem__`` is a cheap row slice.

    Args:
        df: frame holding the feature columns and the target column.
        feature_cols: ordered feature column names; defines the input layout.
        target_col: integer-coded label column (default ``"attack_id"``).

    Raises:
        ValueError: if any feature value is NaN/+-inf or any label is missing.
            Explicit exceptions are used because ``assert`` is stripped under ``python -O``.
    """

    def __init__(self, df: pd.DataFrame, feature_cols, target_col="attack_id"):
        # DataFrame.to_numpy() may return a column-major (F-ordered) array
        # depending on pandas' internal block layout; row access on it is
        # strided. Cast and force row-major order in a single copy.
        self.X = np.ascontiguousarray(df[feature_cols].to_numpy(), dtype=np.float32)
        if not np.isfinite(self.X).all():
            raise ValueError("AttackDataset: feature matrix contains NaN or Inf values.")

        labels = df[target_col]
        # Casting NaN to int64 does not fail; it silently yields garbage ids.
        if labels.isna().any():
            raise ValueError(f"AttackDataset: column {target_col!r} contains missing labels.")
        self.y = labels.to_numpy().astype(np.int64)

    def __len__(self):
        return self.X.shape[0]

    def __getitem__(self, idx):
        return self.X[idx], self.y[idx]

# One dataset per split, built in train/val/test order.
train_dataset, val_dataset, test_dataset = (
    AttackDataset(frame, feature_cols) for frame in (train_df, val_df, test_df)
)

# Only the training loader shuffles and drops the ragged final batch; the
# evaluation loaders keep every sample in a fixed order.
train_loader = DataLoader(
    train_dataset, batch_size=BATCH_SIZE, num_workers=0, shuffle=True, drop_last=True
)
val_loader = DataLoader(
    val_dataset, batch_size=BATCH_SIZE, num_workers=0, shuffle=False, drop_last=False
)
test_loader = DataLoader(
    test_dataset, batch_size=BATCH_SIZE, num_workers=0, shuffle=False, drop_last=False
)

print("\nDataset sizes:")
print("  Train:", len(train_dataset))
print("  Val  :", len(val_dataset))
print("  Test :", len(test_dataset))

# ============================================================
# 6. Single-task MLP model
# ============================================================

class AttackMLP(nn.Module):
    def __init__(self, num_features, num_classes, hidden_dims=(512, 256, 128), dropout=0.3):
        super().__init__()
        layers = []
        in_dim = num_features
        for h in hidden_dims:
            layers.extend([
                nn.Linear(in_dim, h),
                nn.BatchNorm1d(h),
                nn.ReLU(),
                nn.Dropout(dropout),
            ])
            in_dim = h
        self.backbone = nn.Sequential(*layers)
        self.head = nn.Linear(in_dim, num_classes)

        # init
        for m in self.modules():
            if isinstance(m, nn.Linear):
                nn.init.kaiming_normal_(m.weight, nonlinearity="relu")
                if m.bias is not None:
                    nn.init.constant_(m.bias, 0.0)
            elif isinstance(m, nn.BatchNorm1d):
                nn.init.constant_(m.weight, 1.0)
                nn.init.constant_(m.bias, 0.0)

    def forward(self, x):
        x = self.backbone(x)
        return self.head(x)

num_features = len(feature_cols)
print("\nBuilding single-task MLP:")
print("  num_features:", num_features)
print("  num_attacks :", num_attacks)

model = AttackMLP(
    num_features=num_features,
    num_classes=num_attacks,
    hidden_dims=(512, 256, 128),
    dropout=0.3,
).to(device)

print("Total parameters:", sum(p.numel() for p in model.parameters() if p.requires_grad))

# Quick shape sanity check. Run it in eval mode: in train mode BatchNorm folds
# the batch into its running mean/var even under no_grad, so a random dummy
# batch would silently perturb the statistics used at evaluation time.
x_dummy = torch.randn(8, num_features, device=device)
model.eval()
with torch.no_grad():
    logits_dummy = model(x_dummy)
model.train()
print("Dummy logits shape:", logits_dummy.shape)
if tuple(logits_dummy.shape) != (8, num_attacks):
    raise RuntimeError(f"Unexpected logits shape {tuple(logits_dummy.shape)}; expected (8, {num_attacks}).")

# ============================================================
# 7. Loss, optimizer, scheduler
# ============================================================

# Cross-entropy over attack classes. LABEL_SMOOTHING is the config knob
# (0.0 there reproduces plain CE for the baseline); wiring it in here makes the
# config value actually take effect.
criterion = nn.CrossEntropyLoss(label_smoothing=LABEL_SMOOTHING)

optimizer = torch.optim.AdamW(
    model.parameters(),
    lr=LEARNING_RATE,
    weight_decay=WEIGHT_DECAY,
)

# The LR scheduler is stepped once per batch, so its horizon is measured in
# optimizer steps, not epochs.
WARMUP_FRACTION = 0.05  # 5% of all steps are linear warmup
total_steps = len(train_loader) * NUM_EPOCHS
warmup_steps = int(WARMUP_FRACTION * total_steps)

def get_warmup_cosine_schedule(optimizer, warmup_steps, total_steps, min_lr_ratio=0.0):
    """Build a per-step LambdaLR: linear warmup followed by cosine decay.

    The multiplier applied to each param group's base LR is:
      * ``step / warmup_steps`` while ``step < warmup_steps`` (0 at step 0);
      * afterwards ``min_lr_ratio + (1 - min_lr_ratio) * 0.5 * (1 + cos(pi * progress))``,
        where ``progress`` is the fraction of post-warmup steps completed,
        clamped to at most 1.

    The clamp matters: without it, stepping past ``total_steps`` makes
    ``cos(pi * progress)`` rise again and the LR climbs back toward its peak.

    Args:
        optimizer: optimizer whose param groups are scheduled.
        warmup_steps: number of linear-warmup steps (0 disables warmup).
        total_steps: step at which the cosine reaches its floor.
        min_lr_ratio: floor of the multiplier after decay (default 0.0, i.e. decay to zero).

    Returns:
        ``torch.optim.lr_scheduler.LambdaLR``; call ``scheduler.step()`` once per optimizer step.
    """
    def lr_lambda(step):
        if step < warmup_steps:
            return float(step) / float(max(1, warmup_steps))
        progress = float(step - warmup_steps) / float(max(1, total_steps - warmup_steps))
        progress = min(1.0, progress)
        cosine = 0.5 * (1.0 + np.cos(np.pi * progress))
        return float(max(0.0, min_lr_ratio + (1.0 - min_lr_ratio) * cosine))
    return torch.optim.lr_scheduler.LambdaLR(optimizer, lr_lambda)

# Per-step warmup + cosine schedule over the whole planned training run.
scheduler = get_warmup_cosine_schedule(
    optimizer, warmup_steps=warmup_steps, total_steps=total_steps
)
print("\nScheduler: warmup_steps={}, total_steps={}".format(warmup_steps, total_steps))

# ============================================================
# 8. Training / evaluation
# ============================================================

def train_one_epoch(model, loader, optimizer, scheduler, epoch, *,
                    loss_fn=None, max_grad_norm=None, log_every=200):
    """Run one optimisation pass over ``loader`` and return (mean loss, accuracy).

    Each batch: forward, loss, backward, gradient-norm clipping, optimizer
    step, then one scheduler step (the schedule is defined per batch).

    Args:
        model: classifier returning (batch, num_classes) logits; put in train mode.
        loader: iterable of (features, integer labels) batches.
        optimizer: stepped after every batch.
        scheduler: stepped after every optimizer step.
        epoch: epoch number, used only in progress log lines.
        loss_fn: loss module; defaults to the notebook-level ``criterion``.
        max_grad_norm: clipping threshold; defaults to the notebook-level ``MAX_GRAD_NORM``.
        log_every: print a progress line every this many batches (<= 0 disables).

    Returns:
        (avg_loss, avg_acc) as Python floats, weighted by batch size.

    Raises:
        ValueError: if the loader yields no samples (e.g. ``drop_last=True``
            with fewer samples than one batch), instead of dividing by zero.
    """
    loss_fn = criterion if loss_fn is None else loss_fn
    max_grad_norm = MAX_GRAD_NORM if max_grad_norm is None else max_grad_norm
    # Move batches to wherever the model's weights live.
    dev = next(model.parameters()).device

    model.train()
    # Accumulate on-device so a step does not force a host sync via .item();
    # one sync happens at the end (plus one per logged batch).
    loss_sum = torch.zeros((), dtype=torch.float32, device=dev)
    correct = torch.zeros((), dtype=torch.int64, device=dev)
    n_seen = 0

    for batch_idx, (x, y) in enumerate(loader):
        x = x.to(dev)
        y = y.to(dev)

        optimizer.zero_grad()
        logits = model(x)
        loss = loss_fn(logits, y)
        loss.backward()

        torch.nn.utils.clip_grad_norm_(model.parameters(), max_grad_norm)
        optimizer.step()
        scheduler.step()

        batch_size = x.size(0)
        preds = logits.detach().argmax(dim=1)
        batch_correct = (preds == y).sum()
        loss_sum += loss.detach().float() * batch_size
        correct += batch_correct
        n_seen += batch_size

        if log_every > 0 and (batch_idx + 1) % log_every == 0:
            batch_acc = batch_correct.item() / batch_size
            lr = scheduler.get_last_lr()[0]
            print(f"Epoch {epoch} | Batch {batch_idx+1}/{len(loader)} | "
                  f"Loss: {loss.item():.4f} | Batch Acc: {batch_acc:.3f} | LR: {lr:.2e}")

    if n_seen == 0:
        raise ValueError(
            "train_one_epoch: loader yielded no samples "
            "(check dataset size vs batch_size with drop_last=True)."
        )

    avg_loss = loss_sum.item() / n_seen
    avg_acc = correct.item() / n_seen
    return avg_loss, avg_acc


def evaluate(model, loader, *, loss_fn=None):
    """Score ``model`` on ``loader`` without updating any weights.

    Args:
        model: classifier returning (batch, num_classes) logits; switched to eval mode.
        loader: iterable of (features, integer labels) batches.
        loss_fn: loss module; defaults to the notebook-level ``criterion``.

    Returns:
        (avg_loss, avg_acc, y_true, y_pred): the sample-weighted mean loss and
        the accuracy as floats, plus the concatenated labels and argmax
        predictions as NumPy arrays, in loader order.

    Raises:
        ValueError: if the loader yields no samples.
    """
    loss_fn = criterion if loss_fn is None else loss_fn
    # Move batches to wherever the model's weights live.
    dev = next(model.parameters()).device

    model.eval()
    loss_sum = torch.zeros((), dtype=torch.float32, device=dev)
    n_seen = 0
    label_chunks = []
    pred_chunks = []

    with torch.no_grad():
        for x, y in loader:
            x = x.to(dev)
            y = y.to(dev)

            logits = model(x)
            batch_size = x.size(0)
            # Keep everything on-device; a single host transfer happens after the loop.
            loss_sum += loss_fn(logits, y).float() * batch_size
            n_seen += batch_size
            label_chunks.append(y)
            pred_chunks.append(logits.argmax(dim=1))

    if n_seen == 0:
        raise ValueError("evaluate: loader yielded no samples; cannot compute loss/accuracy.")

    y_true = torch.cat(label_chunks).cpu().numpy()
    y_pred = torch.cat(pred_chunks).cpu().numpy()
    avg_loss = loss_sum.item() / n_seen
    avg_acc = int((y_true == y_pred).sum()) / n_seen
    return avg_loss, avg_acc, y_true, y_pred

# ============================================================
# 9. Training loop with early stopping
# ============================================================

# Early-stopping state. best_val_acc starts at -inf so the first epoch always
# writes a checkpoint. Starting from 0.0, a run whose validation accuracy never
# exceeded MIN_DELTA would save nothing, and the test section would then load
# whatever stale checkpoint an earlier run left at best_model_file (or fail).
MIN_DELTA = 1e-4  # minimum val-acc gain that counts as an improvement
best_val_acc = float("-inf")
epochs_no_improve = 0
history = []

best_model_dir = PROJECT_ROOT / "models"
best_model_dir.mkdir(parents=True, exist_ok=True)
best_model_file = best_model_dir / "single_task_attack_mlp_best.pt"
history_file = REPORTS_DIR / "single_task_attack_mlp_history.csv"

print("\n" + "="*60)
print(f"Starting single-task training for {NUM_EPOCHS} epochs")
print("="*60)

for epoch in range(1, NUM_EPOCHS + 1):
    print(f"\nEpoch {epoch}/{NUM_EPOCHS}")
    train_loss, train_acc = train_one_epoch(model, train_loader, optimizer, scheduler, epoch)
    val_loss, val_acc, y_val_true, y_val_pred = evaluate(model, val_loader)

    history.append({
        "epoch": epoch,
        "train_loss": train_loss,
        "train_acc": train_acc,
        "val_loss": val_loss,
        "val_acc": val_acc,
        "lr": scheduler.get_last_lr()[0],
    })
    # Persist the curve every epoch so an interrupted run (e.g. KeyboardInterrupt)
    # still leaves a history file on disk.
    pd.DataFrame(history).to_csv(history_file, index=False)

    print(f"  Train Loss: {train_loss:.4f}, Train Acc: {train_acc:.4f}")
    print(f"  Val   Loss: {val_loss:.4f}, Val   Acc: {val_acc:.4f}")

    if val_acc > best_val_acc + MIN_DELTA:
        best_val_acc = val_acc
        epochs_no_improve = 0
        torch.save(model.state_dict(), best_model_file)
        print(f"  ✓ New best model saved (val_acc={val_acc:.4f})")
    else:
        epochs_no_improve += 1
        print(f"  No improvement for {epochs_no_improve} epoch(s)")

    if epochs_no_improve >= PATIENCE:
        print("\nEarly stopping triggered.")
        break

# Final history table (the same content the per-epoch writes already saved).
hist_df = pd.DataFrame(history)
hist_df.to_csv(history_file, index=False)

# ============================================================
# 10. Final test evaluation
# ============================================================

print("\n" + "="*60)
print("Evaluating best single-task model on TEST set")
print("="*60)

# Rebuild the architecture with the training hyper-parameters; they must match
# the checkpoint or load_state_dict fails on key/shape mismatches.
best_model = AttackMLP(
    num_features=num_features,
    num_classes=num_attacks,
    hidden_dims=(512, 256, 128),
    dropout=0.3,
).to(device)
# The checkpoint is a plain state_dict, so weights_only=True is sufficient and
# refuses to unpickle arbitrary objects from the file.
best_model.load_state_dict(torch.load(best_model_file, map_location=device, weights_only=True))

test_loss, test_acc, y_test_true, y_test_pred = evaluate(best_model, test_loader)

print(f"\nTest Loss: {test_loss:.4f}")
print(f"Test Accuracy: {test_acc:.4f}")

# Pin the label set to every class id. Without `labels`, sklearn infers the
# classes from y_true/y_pred: if any class is absent from both, the report
# raises on the target_names length mismatch and the confusion matrix shrinks,
# losing its alignment with attack_names.
class_ids = list(range(num_attacks))
attack_names = [attack_label_mapping[str(i)] for i in class_ids]

print("\nClassification report (attack_id):")
print(
    classification_report(
        y_test_true,
        y_test_pred,
        labels=class_ids,
        target_names=attack_names,
        digits=4,
        zero_division=0,
    )
)

print("\nConfusion matrix (attack_id):")
test_cm = confusion_matrix(y_test_true, y_test_pred, labels=class_ids)
# Rows = true class, columns = predicted class (sklearn convention), labelled by name.
print(pd.DataFrame(test_cm, index=attack_names, columns=attack_names).to_string())

print("\nDone.")

PROJECT_ROOT: /Users/naeemulhassan/naeem-p/Cloud-Deployed-Multitask-IoT-IDS
Using processed data from: /Users/naeemulhassan/naeem-p/Cloud-Deployed-Multitask-IoT-IDS/data/processed
Using device: mps

Loading train/val/test CSVs (full)...
Train shape: (2126280, 139)
Val   shape: (455632, 139)
Test  shape: (455632, 139)
Number of attack classes: 8

Number of feature columns: 120
Example features: ['stream', 'src_port', 'dst_port', 'inter_arrival_time', 'time_since_previously_displayed_frame', 'port_class_dst', 'l4_tcp', 'l4_udp', 'ttl', 'eth_size', 'tcp_window_size', 'payload_entropy', 'handshake_cipher_suites_length', 'handshake_ciphersuites', 'handshake_extensions_length']

Cleaning NaN/Inf...
  [train] NaN before fill: 20652772, filling with 0.
  [val] NaN before fill: 4432636, filling with 0.
  [test] NaN before fill: 4426542, filling with 0.

Standardising features...
  [train] NaN after std: 0, Inf: 0
  [val] NaN after std: 0, Inf: 0
  [test] NaN after std: 0, Inf: 0

Dataset sizes:

KeyboardInterrupt: 

In [2]:
# Standalone re-run of the test-set evaluation (same steps as section 10),
# usable when only the checkpoint and the earlier definitions are available.
print("\n" + "="*60)
print("Evaluating best single-task model on TEST set")
print("="*60)

# Rebuild the architecture with the training hyper-parameters; they must match
# the checkpoint or load_state_dict fails on key/shape mismatches.
best_model = AttackMLP(
    num_features=num_features,
    num_classes=num_attacks,
    hidden_dims=(512, 256, 128),
    dropout=0.3,
).to(device)
# The checkpoint is a plain state_dict, so weights_only=True is sufficient and
# refuses to unpickle arbitrary objects from the file.
best_model.load_state_dict(torch.load(best_model_file, map_location=device, weights_only=True))

test_loss, test_acc, y_test_true, y_test_pred = evaluate(best_model, test_loader)

print(f"\nTest Loss: {test_loss:.4f}")
print(f"Test Accuracy: {test_acc:.4f}")

# Pin the label set to every class id so a class missing from both y_true and
# y_pred cannot break the report (target_names mismatch) or shrink the matrix.
class_ids = list(range(num_attacks))
attack_names = [attack_label_mapping[str(i)] for i in class_ids]

print("\nClassification report (attack_id):")
print(
    classification_report(
        y_test_true,
        y_test_pred,
        labels=class_ids,
        target_names=attack_names,
        digits=4,
        zero_division=0,
    )
)

print("\nConfusion matrix (attack_id):")
test_cm = confusion_matrix(y_test_true, y_test_pred, labels=class_ids)
# Rows = true class, columns = predicted class (sklearn convention), labelled by name.
print(pd.DataFrame(test_cm, index=attack_names, columns=attack_names).to_string())


Evaluating best single-task model on TEST set

Test Loss: 0.1145
Test Accuracy: 0.9583

Classification report (attack_id):
              precision    recall  f1-score   support

      benign     0.9009    0.9728    0.9354     67500
 brute force     0.9982    0.9630    0.9803     19722
        ddos     0.9963    0.9863    0.9913     67500
         dos     0.9880    0.9604    0.9740     67500
       mirai     0.9991    0.9863    0.9926     67500
       recon     0.9653    0.9255    0.9450     67500
    spoofing     0.9819    0.9495    0.9654     67500
   web-based     0.7922    0.8881    0.8374     30910

    accuracy                         0.9583    455632
   macro avg     0.9527    0.9540    0.9527    455632
weighted avg     0.9608    0.9583    0.9590    455632


Confusion matrix (attack_id):
[[65661     2     1    53     6   326   125  1326]
 [  289 18993     1    36     2    27    65   309]
 [  239     3 66576   145     0   311     2   224]
 [ 1290     6   141 64827     9   626    