In [None]:
# --- Script: 07_multitask_cnn_staged.py ---
# Goal:
# - Multitask model with shared CNN backbone:
#     Task 1: Intrusion Detection (attack_id, 8 classes)
#     Task 2: Device Identification (device_id, 94 classes)
# - Stage 1: Train attack head only (single-task, high accuracy)
# - Stage 2: Freeze backbone+attack head, train device head only
# - Stage 3: Small joint fine-tuning (attack-dominant loss)

# ============================================================
# 0. Environment & paths
# ============================================================

import sys
import os
from pathlib import Path
import json
import math

import numpy as np
import pandas as pd

import torch
import torch.nn as nn
import torch.nn.functional as F
from torch.utils.data import Dataset, DataLoader

from sklearn.metrics import classification_report, confusion_matrix

# Reproducibility: seed the NumPy and PyTorch global random generators.
# NOTE(review): no deterministic-algorithm flags are set here, so bit-exact
# repeatability on accelerator backends is not guaranteed — confirm if needed.
RANDOM_SEED = 42
np.random.seed(RANDOM_SEED)
torch.manual_seed(RANDOM_SEED)

# The project root is the parent of the current working directory
# (presumably the notebook is launched from a sub-folder of the repo — TODO confirm).
PROJECT_ROOT = Path(os.getcwd()).resolve().parents[0]
# Expose the project root on the import path.
sys.path.append(str(PROJECT_ROOT))

# Input and output locations; the two output directories are created if absent.
DATA_DIR = PROJECT_ROOT / "data"
PROCESSED_DIR = DATA_DIR / "processed"
REPORTS_DIR = PROJECT_ROOT / "reports"
MODELS_DIR = PROJECT_ROOT / "models"
REPORTS_DIR.mkdir(parents=True, exist_ok=True)
MODELS_DIR.mkdir(parents=True, exist_ok=True)

print("PROJECT_ROOT:", PROJECT_ROOT)
print("Using processed data from:", PROCESSED_DIR)

# Device selection, in priority order: Apple MPS, then CUDA, then CPU.
if torch.backends.mps.is_available():
    device = torch.device("mps")
elif torch.cuda.is_available():
    device = torch.device("cuda")
else:
    device = torch.device("cpu")
print("Using device:", device)

# ============================================================
# 1. Config
# ============================================================

# Pre-split CSVs, read from PROCESSED_DIR.
TRAIN_PATH = PROCESSED_DIR / "packets_train.csv"
VAL_PATH   = PROCESSED_DIR / "packets_val.csv"
TEST_PATH  = PROCESSED_DIR / "packets_test.csv"

# JSON files holding the id -> class-name mappings for each task.
ATTACK_LABEL_MAP_PATH = PROCESSED_DIR / "attack_label_mapping.json"
DEVICE_LABEL_MAP_PATH = PROCESSED_DIR / "device_label_mapping.json"

# Mini-batch size shared by the train, validation and test loaders.
BATCH_SIZE = 512

# Slightly longer training to stabilise backbone
NUM_EPOCHS_STAGE1 = 6   # attack-only training
NUM_EPOCHS_STAGE2 = 6   # device-only training
NUM_EPOCHS_STAGE3 = 6   # optional joint fine-tune

# Peak learning rate per stage (each stage builds its own warmup+cosine schedule).
LEARNING_RATE_STAGE1 = 3e-4
LEARNING_RATE_STAGE2 = 3e-4
LEARNING_RATE_STAGE3 = 1e-5  # much smaller for fine-tuning

# AdamW weight decay and gradient-norm clipping threshold used in all stages.
WEIGHT_DECAY = 1e-4
MAX_GRAD_NORM = 1.0
# Early-stopping patience: epochs without validation improvement before a stage stops.
# With 6 epochs per stage, a patience of 5 can only trigger on the final epoch.
PATIENCE_STAGE1 = 5
PATIENCE_STAGE2 = 5
PATIENCE_STAGE3 = 3

# Joint fine-tune flag
DO_JOINT_FINETUNE = True

# Loss weights in joint phase: attack dominates
LAMBDA_ATTACK_JOINT = 0.9
LAMBDA_DEVICE_JOINT = 0.1

# Validation selection in Stage 3: bias to attack performance
JOINT_METRIC_ATTACK_WEIGHT = 0.7
JOINT_METRIC_DEVICE_WEIGHT = 0.3

# Label smoothing for attack loss (helps stability slightly)
# The same smoothed criterion is also used when reporting attack validation/test loss.
ATTACK_LABEL_SMOOTH = 0.05

# ============================================================
# 2. Load data
# ============================================================

# Read the three pre-split CSVs fully into memory.
print("\nLoading processed CSVs...")
train_df = pd.read_csv(TRAIN_PATH)
val_df   = pd.read_csv(VAL_PATH)
test_df  = pd.read_csv(TEST_PATH)

print("Train shape:", train_df.shape)
print("Val   shape:", val_df.shape)
print("Test  shape:", test_df.shape)

# Label mappings: each JSON holds an id -> name dictionary under the key shown.
# Later code looks names up with str(i), i.e. the ids are string keys.
with open(ATTACK_LABEL_MAP_PATH, "r") as f:
    attack_label_mapping = json.load(f)["id_to_attack"]
with open(DEVICE_LABEL_MAP_PATH, "r") as f:
    device_label_mapping = json.load(f)["id_to_device"]

# The class counts, taken from the mappings (not from the data), size the two heads.
num_attacks = len(attack_label_mapping)
num_devices = len(device_label_mapping)
print("Number of attack classes:", num_attacks)
print("Number of device classes:", num_devices)

# ============================================================
# 3. Feature selection
# ============================================================

# Names of the two label columns.
TARGET_ATTACK = "attack_id"
TARGET_DEVICE = "device_id"

# Features = every int64/float64 column of the training frame except the two targets.
# Columns of any other dtype (e.g. object, bool, int32) are not selected.
numeric_cols = train_df.select_dtypes(include=["int64", "float64"]).columns.tolist()
feature_cols = [c for c in numeric_cols if c not in [TARGET_ATTACK, TARGET_DEVICE]]

print("\nNumber of feature columns:", len(feature_cols))
print("Example features:", feature_cols[:15])

# Fail early if either label column is missing from the training frame.
# NOTE(review): only train_df is checked; val/test are assumed to share its schema.
if TARGET_ATTACK not in train_df.columns or TARGET_DEVICE not in train_df.columns:
    raise ValueError("attack_id/device_id target columns not found in train_df.")

# ============================================================
# 4. Robust preprocessing: NaN/Inf cleaning + standardisation
# ============================================================

def clean_df(df, feature_cols, name):
    """Replace +/-Inf and NaN in the feature columns of `df` with 0.

    The frame is modified in place and also returned. `name` is only used to
    label the progress message, which is printed when at least one missing
    value is found after the Inf -> NaN conversion.
    """
    # Step 1: fold infinities into NaN so they are handled uniformly.
    df[feature_cols] = df[feature_cols].replace([np.inf, -np.inf], np.nan)

    # Step 2: count the missing cells over all feature columns.
    missing_total = df[feature_cols].isna().sum().sum()
    if not missing_total > 0:
        return df

    # Step 3: report, then zero-fill.
    print(f"  [{name}] NaN before fill: {missing_total}, filling with 0.")
    df[feature_cols] = df[feature_cols].fillna(0)
    return df

# Replace Inf/NaN with 0 in every split (clean_df mutates and returns the frame).
print("\nCleaning NaN/Inf...")
train_df = clean_df(train_df, feature_cols, "train")
val_df   = clean_df(val_df, feature_cols, "val")
test_df  = clean_df(test_df, feature_cols, "test")

# Standardisation statistics come from the training split only and are applied
# to all three splits. A zero standard deviation (constant column) is replaced
# by 1.0 so the division below cannot produce Inf/NaN.
print("\nStandardising features...")
means = train_df[feature_cols].mean()
stds  = train_df[feature_cols].std().replace(0, 1.0)

# Column assignment on `df` updates the original frames in place, so
# train_df/val_df/test_df hold the scaled values after this loop.
for df, name in [(train_df, "train"), (val_df, "val"), (test_df, "test")]:
    df[feature_cols] = (df[feature_cols] - means) / stds
    # Clamp extreme z-scores to the range [-10, 10].
    df[feature_cols] = df[feature_cols].clip(-10, 10)
    n_nan = df[feature_cols].isna().sum().sum()
    n_inf = np.isinf(df[feature_cols].values).sum()
    print(f"  [{name}] NaN after std: {n_nan}, Inf: {n_inf}")
    if n_nan > 0 or n_inf > 0:
        raise ValueError(f"Found NaN/Inf in {name} after standardisation.")

# Save scaler (useful for deployment)
# The JSON holds per-feature means and (zero-replaced) standard deviations.
# The [-10, 10] clip is not stored, so a consumer has to re-apply it itself.
scaler_path = PROCESSED_DIR / "multitask_cnn_mlp_scaler.json"
with open(scaler_path, "w") as f:
    json.dump({"means": means.to_dict(), "stds": stds.to_dict()}, f, indent=2)

# ============================================================
# 5. Dataset / DataLoader
# ============================================================

class MultiTaskDataset(Dataset):
    """In-memory dataset yielding (features, attack label, device label) triples.

    The selected feature columns are stored as one float32 array and the two
    label columns as int64 arrays; indexing returns the row of each array.

    Args:
        df: frame containing the feature columns and both label columns.
        feature_cols: names of the columns used as model input.
        attack_col: name of the attack label column.
        device_col: name of the device label column.

    Raises:
        ValueError: if any feature value is NaN or infinite after the
            float32 conversion.
    """
    def __init__(self, df: pd.DataFrame, feature_cols, attack_col="attack_id", device_col="device_id"):
        self.X = df[feature_cols].values.astype(np.float32)
        self.y_attack = df[attack_col].values.astype(np.int64)
        self.y_device = df[device_col].values.astype(np.int64)
        # Raise explicitly instead of using `assert`: assertions are removed
        # when Python runs with -O, which would let bad values reach training.
        # The check runs after the float32 cast so overflow to Inf is caught too.
        if not np.isfinite(self.X).all():
            raise ValueError("MultiTaskDataset: feature matrix contains NaN or Inf values.")

    def __len__(self):
        """Return the number of rows."""
        return self.X.shape[0]

    def __getitem__(self, idx):
        """Return (feature row, attack label, device label) for row `idx`."""
        return self.X[idx], self.y_attack[idx], self.y_device[idx]

# Wrap each (already standardised) split in a dataset; the default label
# column names "attack_id"/"device_id" are used.
train_dataset = MultiTaskDataset(train_df, feature_cols)
val_dataset   = MultiTaskDataset(val_df, feature_cols)
test_dataset  = MultiTaskDataset(test_df, feature_cols)

# Training batches are shuffled and the last incomplete batch is dropped;
# validation/test keep file order and every sample. All loading happens in
# the main process (num_workers=0).
train_loader = DataLoader(train_dataset, batch_size=BATCH_SIZE, shuffle=True,
                          drop_last=True, num_workers=0)
val_loader   = DataLoader(val_dataset, batch_size=BATCH_SIZE, shuffle=False,
                          drop_last=False, num_workers=0)
test_loader  = DataLoader(test_dataset, batch_size=BATCH_SIZE, shuffle=False,
                          drop_last=False, num_workers=0)

print("\nDataset sizes:")
print("  Train:", len(train_dataset))
print("  Val  :", len(val_dataset))
print("  Test :", len(test_dataset))

# ============================================================
# 6. Model: 1D CNN backbone + two heads
# ============================================================

class CNNBackbone1D(nn.Module):
    """
    Feature extractor shared by both task heads.

    A flat feature vector (B, F) is viewed as a single-channel signal of
    length F and passed through:

        1. three Conv1d -> BatchNorm1d -> ReLU stages (kernels 5, 5, 3,
           length-preserving padding),
        2. a stride-2 max pool,
        3. a mean over the remaining length,
        4. a Linear -> BatchNorm1d -> ReLU -> Dropout projection.

    The result is a (B, rep_dim) representation.
    """
    def __init__(
        self,
        num_features: int,
        conv_channels: int = 64,
        conv_channels_mid: int = 128,
        rep_dim: int = 128,
        dropout: float = 0.3,
    ):
        super().__init__()
        self.num_features = num_features
        self.rep_dim = rep_dim

        # Convolutional stages; the input has one channel after unsqueeze.
        self.conv_block1 = self._conv_bn_relu(1, conv_channels, kernel_size=5)
        self.conv_block2 = self._conv_bn_relu(conv_channels, conv_channels_mid, kernel_size=5)
        # Narrower kernel for the last stage.
        self.conv_block3 = self._conv_bn_relu(conv_channels_mid, conv_channels_mid, kernel_size=3)

        # Halves the sequence length before the global average.
        self.pool = nn.MaxPool1d(kernel_size=2, stride=2)

        # Maps the channel-wise averages to the shared representation.
        projection_layers = [
            nn.Linear(conv_channels_mid, rep_dim),
            nn.BatchNorm1d(rep_dim),
            nn.ReLU(),
            nn.Dropout(dropout),
        ]
        self.proj = nn.Sequential(*projection_layers)

        self._init_weights()

    @staticmethod
    def _conv_bn_relu(in_channels, out_channels, kernel_size):
        """Build one Conv1d + BatchNorm1d + ReLU stage with 'same'-length padding."""
        return nn.Sequential(
            nn.Conv1d(in_channels, out_channels, kernel_size=kernel_size, padding=kernel_size // 2),
            nn.BatchNorm1d(out_channels),
            nn.ReLU(),
        )

    def _init_weights(self):
        """Kaiming-normal init for conv/linear weights, zero biases, identity BatchNorm."""
        for layer in self.modules():
            if isinstance(layer, (nn.Conv1d, nn.Linear)):
                nn.init.kaiming_normal_(layer.weight, nonlinearity="relu")
                if layer.bias is not None:
                    nn.init.zeros_(layer.bias)
            elif isinstance(layer, nn.BatchNorm1d):
                nn.init.ones_(layer.weight)
                nn.init.zeros_(layer.bias)

    def forward(self, x: torch.Tensor) -> torch.Tensor:
        """Map a (B, F) batch to its (B, rep_dim) shared representation."""
        feats = x.unsqueeze(1)                       # (B, 1, F)
        for stage in (self.conv_block1, self.conv_block2, self.conv_block3):
            feats = stage(feats)                     # (B, C, F)
        feats = self.pool(feats)                     # (B, C2, F/2)
        pooled = feats.mean(dim=2)                   # (B, C2)
        return self.proj(pooled)                     # (B, rep_dim)


class MultiTaskCNN1D(nn.Module):
    """
    Two-task classifier built on one shared CNNBackbone1D.

    The backbone output feeds two independent linear layers:
        - attack_head: logits over `num_attacks` classes
        - device_head: logits over `num_devices` classes
    """
    def __init__(
        self,
        num_features: int,
        num_attacks: int,
        num_devices: int,
        conv_channels: int = 64,
        conv_channels_mid: int = 128,
        rep_dim: int = 128,
        dropout: float = 0.3,
    ):
        super().__init__()

        backbone_kwargs = dict(
            num_features=num_features,
            conv_channels=conv_channels,
            conv_channels_mid=conv_channels_mid,
            rep_dim=rep_dim,
            dropout=dropout,
        )
        self.backbone = CNNBackbone1D(**backbone_kwargs)

        self.attack_head = nn.Linear(rep_dim, num_attacks)
        self.device_head = nn.Linear(rep_dim, num_devices)

        # Both heads get the same initialisation: Kaiming-normal weights
        # (linear gain) and zero biases.
        for head in (self.attack_head, self.device_head):
            nn.init.kaiming_normal_(head.weight, nonlinearity="linear")
            nn.init.zeros_(head.bias)

    def forward(self, x: torch.Tensor):
        """Return (attack logits, device logits) for the input batch."""
        shared = self.backbone(x)
        return self.attack_head(shared), self.device_head(shared)


# Instantiate model: input width is the number of selected feature columns,
# head widths come from the two label mappings.
num_features = len(feature_cols)
print("\nBuilding MultiTask 1D-CNN model:")
print("  num_features:", num_features)
print("  num_attacks :", num_attacks)
print("  num_devices :", num_devices)

model = MultiTaskCNN1D(
    num_features=num_features,
    num_attacks=num_attacks,
    num_devices=num_devices,
    conv_channels=64,
    conv_channels_mid=128,
    rep_dim=128,
    dropout=0.3,
).to(device)

# Counts trainable parameters only (all of them at this point).
print("Total parameters:", sum(p.numel() for p in model.parameters() if p.requires_grad))

# Sanity check: push one random batch through the network and print the logit
# shapes. The pass runs in eval mode because a train-mode forward updates the
# BatchNorm running statistics even under no_grad, which would fold random
# noise into the freshly initialised buffers. The previous mode is restored.
x_dummy = torch.randn(8, num_features, device=device)
was_training = model.training
model.eval()
with torch.no_grad():
    logits_att_dummy, logits_dev_dummy = model(x_dummy)
    print("Dummy attack logits shape:", logits_att_dummy.shape)
    print("Dummy device logits shape:", logits_dev_dummy.shape)
model.train(was_training)

# ============================================================
# 7. Scheduler helper
# ============================================================

def get_warmup_cosine_schedule(optimizer, warmup_steps, total_steps):
    """Build a per-step LR schedule: linear warmup followed by cosine decay.

    The returned scheduler multiplies each param group's base learning rate by
    a factor that rises linearly from 0 to 1 over `warmup_steps` steps and then
    follows half a cosine from 1 down to 0 at `total_steps`. For any step
    beyond `total_steps` the factor stays at 0.

    Args:
        optimizer: optimizer whose learning rate is scheduled.
        warmup_steps: number of warmup steps (0 disables warmup).
        total_steps: step at which the factor reaches 0.

    Returns:
        A `torch.optim.lr_scheduler.LambdaLR`; call `.step()` once per batch.
    """
    def lr_lambda(step):
        if step < warmup_steps:
            return float(step) / float(max(1, warmup_steps))
        progress = float(step - warmup_steps) / float(max(1, total_steps - warmup_steps))
        # Clamp so the cosine cannot swing back up once total_steps is
        # exceeded (unclamped, the factor climbs back towards 1).
        progress = min(1.0, progress)
        return max(0.0, 0.5 * (1.0 + math.cos(math.pi * progress)))
    return torch.optim.lr_scheduler.LambdaLR(optimizer, lr_lambda)

# ============================================================
# 8. Stage 1: attack-only training
# ============================================================

# Loss functions, read as module-level globals by the train/eval helpers below.
# The attack loss uses label smoothing (ATTACK_LABEL_SMOOTH), also when it is
# computed for validation/test reporting; the device loss is plain cross-entropy.
criterion_attack = nn.CrossEntropyLoss(label_smoothing=ATTACK_LABEL_SMOOTH)
criterion_device = nn.CrossEntropyLoss()  # used later

def train_epoch_attack_only(model, loader, optimizer, scheduler, epoch):
    """Train for one epoch on the attack loss only.

    Each batch does: forward pass, attack cross-entropy, backward, gradient
    clipping at MAX_GRAD_NORM, optimizer step, scheduler step. The device
    labels and device logits are ignored. A progress line is printed every
    200 batches; `epoch` is used only in that message.

    Returns:
        (mean loss per sample, accuracy) over all samples seen in the epoch.
    """
    model.train()
    loss_sum = 0.0
    n_correct = 0
    n_seen = 0

    for step, (x, y_att, _y_dev) in enumerate(loader, start=1):
        x = x.to(device)
        y_att = y_att.to(device)

        optimizer.zero_grad()
        logits_att, _logits_dev = model(x)
        loss = criterion_attack(logits_att, y_att)
        loss.backward()
        torch.nn.utils.clip_grad_norm_(model.parameters(), MAX_GRAD_NORM)
        optimizer.step()
        scheduler.step()

        # Running totals, weighted by the batch size.
        n_batch = x.size(0)
        preds = logits_att.argmax(dim=1)
        hits = preds == y_att
        loss_sum += loss.item() * n_batch
        n_correct += hits.sum().item()
        n_seen += n_batch

        if step % 200 == 0:
            batch_acc = hits.float().mean().item()
            lr = scheduler.get_last_lr()[0]
            print(f"[Stage1] Epoch {epoch} | Batch {step}/{len(loader)} | "
                  f"Loss: {loss.item():.4f} | Batch Acc: {batch_acc:.3f} | LR: {lr:.2e}")

    mean_loss = loss_sum / n_seen
    accuracy = n_correct / n_seen
    return mean_loss, accuracy


def eval_attack_only(model, loader):
    """Evaluate the attack head over `loader` without gradient tracking.

    Returns:
        (mean attack loss per sample, attack accuracy,
         array of true attack labels, array of predicted attack labels).
    """
    model.eval()
    loss_sum, n_correct, n_seen = 0.0, 0, 0
    true_chunks, pred_chunks = [], []

    with torch.no_grad():
        for x, y_att, _y_dev in loader:
            x = x.to(device)
            y_att = y_att.to(device)

            logits_att, _logits_dev = model(x)
            loss = criterion_attack(logits_att, y_att)
            preds = logits_att.argmax(dim=1)

            n_batch = x.size(0)
            loss_sum += loss.item() * n_batch
            n_correct += (preds == y_att).sum().item()
            n_seen += n_batch

            # Keep labels and predictions on the host for the final report.
            true_chunks.append(y_att.cpu().numpy())
            pred_chunks.append(preds.cpu().numpy())

    mean_loss = loss_sum / n_seen
    accuracy = n_correct / n_seen
    return mean_loss, accuracy, np.concatenate(true_chunks), np.concatenate(pred_chunks)


print("\n" + "="*60)
print("STAGE 1: Attack-only training (single-task)")
print("="*60)

# The optimizer is given every parameter, but only the attack loss is
# back-propagated in this stage, so the device head receives no gradient.
optimizer_stage1 = torch.optim.AdamW(
    model.parameters(),
    lr=LEARNING_RATE_STAGE1,
    weight_decay=WEIGHT_DECAY,
)
# The schedule is stepped once per batch: 5% linear warmup, then cosine decay
# over the full planned number of steps.
total_steps_stage1 = len(train_loader) * NUM_EPOCHS_STAGE1
warmup_steps_stage1 = int(0.05 * total_steps_stage1)
scheduler_stage1 = get_warmup_cosine_schedule(optimizer_stage1, warmup_steps_stage1, total_steps_stage1)

# Early-stopping / checkpoint bookkeeping (selection metric: validation accuracy).
best_val_acc_stage1 = 0.0
epochs_no_improve1 = 0
history_stage1 = []
best_attack_model_path = MODELS_DIR / "multitask_cnn_mlp_stage1_attack_pretrained.pt"

for epoch in range(1, NUM_EPOCHS_STAGE1 + 1):
    print(f"\nEpoch {epoch}/{NUM_EPOCHS_STAGE1}")
    train_loss, train_acc = train_epoch_attack_only(model, train_loader, optimizer_stage1, scheduler_stage1, epoch)
    val_loss, val_acc, y_val_true, y_val_pred = eval_attack_only(model, val_loader)

    history_stage1.append({
        "epoch": epoch,
        "train_loss": train_loss,
        "train_acc": train_acc,
        "val_loss": val_loss,
        "val_acc": val_acc,
        "lr": scheduler_stage1.get_last_lr()[0],
    })

    print(f"  [Stage1] Train Loss: {train_loss:.4f}, Train Acc: {train_acc:.4f}")
    print(f"  [Stage1] Val   Loss: {val_loss:.4f}, Val   Acc: {val_acc:.4f}")

    # Save a checkpoint only when validation accuracy improves by more than 1e-4.
    # NOTE(review): if no epoch ever beats the initial 0.0 (e.g. val_acc is NaN),
    # nothing is written and the later torch.load reads a missing or stale file.
    if val_acc > best_val_acc_stage1 + 1e-4:
        best_val_acc_stage1 = val_acc
        epochs_no_improve1 = 0
        torch.save(model.state_dict(), best_attack_model_path)
        print(f"  ✓ [Stage1] New best attack-only model saved (val_acc={val_acc:.4f})")
    else:
        epochs_no_improve1 += 1
        print(f"  [Stage1] No improvement for {epochs_no_improve1} epoch(s)")

    if epochs_no_improve1 >= PATIENCE_STAGE1:
        print("\n[Stage1] Early stopping triggered.")
        break

# Persist the per-epoch metrics.
pd.DataFrame(history_stage1).to_csv(REPORTS_DIR / "multitask_stage1_attack_history.csv", index=False)

# Load best attack-only weights before Stage 2
model.load_state_dict(torch.load(best_attack_model_path, map_location=device))

# Evaluate attack head on TEST after Stage 1
print("\n" + "="*60)
print("STAGE 1: Evaluate attack-only model on TEST set")
print("="*60)
test_loss1, test_acc1, y_test_true1, y_test_pred1 = eval_attack_only(model, test_loader)
print(f"[Stage1] Test Loss: {test_loss1:.4f}")
print(f"[Stage1] Test Accuracy: {test_acc1:.4f}")

# Class names ordered by integer id; the mapping is keyed by the id as a string.
attack_names = [attack_label_mapping[str(i)] for i in range(num_attacks)]
# `labels` is passed explicitly so the report always has one row per attack
# class. Without it, scikit-learn infers the labels from the data and raises
# when a class is absent from both y_true and y_pred, because the inferred
# label count no longer matches len(target_names).
stage1_report = classification_report(
    y_test_true1,
    y_test_pred1,
    labels=list(range(num_attacks)),
    target_names=attack_names,
    digits=4,
    zero_division=0,
)
print("\n[Stage1] Classification report (attack_id):")
print(stage1_report)

# The confusion matrix uses the same fixed label set, so its shape is always
# (num_attacks, num_attacks) and its rows line up with the report.
with open(REPORTS_DIR / "multitask_stage1_attack_test_report.txt", "w") as f:
    f.write(stage1_report)
    f.write("\n\nConfusion matrix:\n")
    f.write(str(confusion_matrix(y_test_true1, y_test_pred1, labels=list(range(num_attacks)))))

# ============================================================
# 9. Stage 2: freeze backbone+attack head, train device head only
# ============================================================

print("\n" + "="*60)
print("STAGE 2: Device-only training (backbone + attack head frozen)")
print("="*60)

# Freeze backbone and attack head
# requires_grad=False stops gradient updates to these parameters. It does not
# change the train/eval mode of BatchNorm or Dropout layers in the backbone.
for param in model.backbone.parameters():
    param.requires_grad = False
for param in model.attack_head.parameters():
    param.requires_grad = False

# Optimizer only on device head
optimizer_stage2 = torch.optim.AdamW(
    model.device_head.parameters(),
    lr=LEARNING_RATE_STAGE2,
    weight_decay=WEIGHT_DECAY,
)

# Fresh per-batch schedule for this stage: 5% warmup, then cosine decay.
total_steps_stage2 = len(train_loader) * NUM_EPOCHS_STAGE2
warmup_steps_stage2 = int(0.05 * total_steps_stage2)
scheduler_stage2 = get_warmup_cosine_schedule(optimizer_stage2, warmup_steps_stage2, total_steps_stage2)

def train_epoch_device_only(model, loader, optimizer, scheduler, epoch):
    """Train for one epoch on the device loss only, with the backbone frozen.

    Each batch does: forward pass, device cross-entropy, backward, gradient
    clipping on the device head at MAX_GRAD_NORM, optimizer step, scheduler
    step. The attack labels and attack logits are ignored. A progress line is
    printed every 200 batches; `epoch` is used only in that message.

    The backbone is held in eval mode for the whole epoch. Setting
    requires_grad=False on its parameters is not enough to freeze it: in train
    mode its BatchNorm layers keep updating their running statistics and its
    Dropout stays active, so the features feeding the already-trained attack
    head would drift during this stage.

    Returns:
        (mean loss per sample, accuracy) over all samples seen in the epoch.
    """
    model.train()
    # Keep the frozen feature extractor deterministic and its buffers untouched.
    model.backbone.eval()
    total_loss = 0.0
    total_correct = 0
    total_samples = 0

    for batch_idx, (x, y_att, y_dev) in enumerate(loader):
        x = x.to(device)
        y_dev = y_dev.to(device)

        optimizer.zero_grad()
        _, logits_dev = model(x)
        loss = criterion_device(logits_dev, y_dev)
        loss.backward()
        # Only the device head has gradients in this stage.
        torch.nn.utils.clip_grad_norm_(model.device_head.parameters(), MAX_GRAD_NORM)
        optimizer.step()
        scheduler.step()

        batch_size = x.size(0)
        total_loss += loss.item() * batch_size
        preds = logits_dev.argmax(dim=1)
        total_correct += (preds == y_dev).sum().item()
        total_samples += batch_size

        if (batch_idx + 1) % 200 == 0:
            batch_acc = (preds == y_dev).float().mean().item()
            lr = scheduler.get_last_lr()[0]
            print(f"[Stage2] Epoch {epoch} | Batch {batch_idx+1}/{len(loader)} | "
                  f"Loss: {loss.item():.4f} | Batch Acc: {batch_acc:.3f} | LR: {lr:.2e}")

    return total_loss / total_samples, total_correct / total_samples


def eval_both_heads(model, loader):
    """
    Score the attack head and the device head on `loader` in eval mode with
    gradients disabled.

    Returns an 8-tuple:
        (mean attack loss, mean device loss, attack accuracy, device accuracy,
         true attack labels, predicted attack labels,
         true device labels, predicted device labels)
    where the losses are per-sample means and the label entries are arrays.
    """
    model.eval()
    loss_sum_att = 0.0
    loss_sum_dev = 0.0
    n_seen = 0
    collected = {"att_true": [], "att_pred": [], "dev_true": [], "dev_pred": []}

    with torch.no_grad():
        for x, y_att, y_dev in loader:
            x = x.to(device)
            y_att = y_att.to(device)
            y_dev = y_dev.to(device)

            logits_att, logits_dev = model(x)
            loss_att = criterion_attack(logits_att, y_att)
            loss_dev = criterion_device(logits_dev, y_dev)

            # Batch losses are means, so weight them by the batch size.
            n_batch = x.size(0)
            loss_sum_att += loss_att.item() * n_batch
            loss_sum_dev += loss_dev.item() * n_batch
            n_seen += n_batch

            collected["att_true"].append(y_att.cpu().numpy())
            collected["att_pred"].append(logits_att.argmax(dim=1).cpu().numpy())
            collected["dev_true"].append(y_dev.cpu().numpy())
            collected["dev_pred"].append(logits_dev.argmax(dim=1).cpu().numpy())

    mean_loss_att = loss_sum_att / n_seen
    mean_loss_dev = loss_sum_dev / n_seen

    y_att_true = np.concatenate(collected["att_true"])
    y_att_pred = np.concatenate(collected["att_pred"])
    y_dev_true = np.concatenate(collected["dev_true"])
    y_dev_pred = np.concatenate(collected["dev_pred"])

    att_acc = (y_att_true == y_att_pred).mean()
    dev_acc = (y_dev_true == y_dev_pred).mean()

    return (
        mean_loss_att,
        mean_loss_dev,
        att_acc,
        dev_acc,
        y_att_true,
        y_att_pred,
        y_dev_true,
        y_dev_pred,
    )


# Early-stopping / checkpoint bookkeeping (selection metric: validation device accuracy).
best_dev_acc_stage2 = 0.0
epochs_no_improve2 = 0
history_stage2 = []
best_stage2_model_path = MODELS_DIR / "multitask_cnn_mlp_stage2_device_trained.pt"

for epoch in range(1, NUM_EPOCHS_STAGE2 + 1):
    print(f"\n[Stage2] Epoch {epoch}/{NUM_EPOCHS_STAGE2}")
    train_loss_dev, train_acc_dev = train_epoch_device_only(model, train_loader, optimizer_stage2, scheduler_stage2, epoch)
    # Both heads are evaluated so attack accuracy can be monitored while only
    # the device head is being trained.
    val_loss_att2, val_loss_dev2, val_att_acc2, val_dev_acc2, _, _, _, _ = eval_both_heads(model, val_loader)

    history_stage2.append({
        "epoch": epoch,
        "train_loss_dev": train_loss_dev,
        "train_acc_dev": train_acc_dev,
        "val_loss_att": val_loss_att2,
        "val_loss_dev": val_loss_dev2,
        "val_att_acc": val_att_acc2,
        "val_dev_acc": val_dev_acc2,
        "lr": scheduler_stage2.get_last_lr()[0],
    })

    print(f"  [Stage2] Train Dev Loss: {train_loss_dev:.4f}, Train Dev Acc: {train_acc_dev:.4f}")
    print(f"  [Stage2] Val   Att Loss: {val_loss_att2:.4f}, Val Att Acc: {val_att_acc2:.4f}")
    print(f"  [Stage2] Val   Dev Loss: {val_loss_dev2:.4f}, Val Dev Acc: {val_dev_acc2:.4f}")

    # Track best by device accuracy, but keep an eye on attack acc
    # (attack accuracy is only printed; it does not affect checkpoint selection).
    if val_dev_acc2 > best_dev_acc_stage2 + 1e-4:
        best_dev_acc_stage2 = val_dev_acc2
        epochs_no_improve2 = 0
        torch.save(model.state_dict(), best_stage2_model_path)
        print(f"  ✓ [Stage2] New best model saved (Val Dev Acc={val_dev_acc2:.4f}, Val Att Acc={val_att_acc2:.4f})")
    else:
        epochs_no_improve2 += 1
        print(f"  [Stage2] No improvement for {epochs_no_improve2} epoch(s)")

    if epochs_no_improve2 >= PATIENCE_STAGE2:
        print("\n[Stage2] Early stopping triggered.")
        break

# Persist the per-epoch metrics.
pd.DataFrame(history_stage2).to_csv(REPORTS_DIR / "multitask_stage2_device_history.csv", index=False)

# Load best Stage 2 weights
model.load_state_dict(torch.load(best_stage2_model_path, map_location=device))

# Evaluate both heads on TEST after Stage 2
print("\n" + "="*60)
print("STAGE 2: Evaluate both heads on TEST set")
print("="*60)
test_loss_att2, test_loss_dev2, test_att_acc2, test_dev_acc2, y_att_true2, y_att_pred2, y_dev_true2, y_dev_pred2 = eval_both_heads(model, test_loader)
print(f"[Stage2] Test Attack Loss: {test_loss_att2:.4f}, Test Attack Acc: {test_att_acc2:.4f}")
print(f"[Stage2] Test Device Loss: {test_loss_dev2:.4f}, Test Device Acc: {test_dev_acc2:.4f}")

print("\n[Stage2] Attack head classification report (TEST):")
# `labels` is passed explicitly (as the device report below already does) so
# the report has one row per attack class. Without it, scikit-learn raises
# when a class is absent from both the true and the predicted labels.
stage2_attack_report = classification_report(
    y_att_true2,
    y_att_pred2,
    labels=list(range(num_attacks)),
    target_names=attack_names,
    digits=4,
    zero_division=0,
)
print(stage2_attack_report)

# Device: top-20 frequent devices only (to keep report readable)
# value_counts() sorts by descending frequency, so the first top_k index
# entries are the most common device ids in the test set. Only samples whose
# true device is among them are kept for the report.
dev_counts = pd.Series(y_dev_true2).value_counts()
top_k = min(20, len(dev_counts))
top_dev_ids = dev_counts.index[:top_k]
mask_top = np.isin(y_dev_true2, top_dev_ids)
y_dev_true_top = y_dev_true2[mask_top]
y_dev_pred_top = y_dev_pred2[mask_top]
top_dev_names = [device_label_mapping[str(i)] for i in top_dev_ids]

print("\n[Stage2] Device head classification report (TEST) - Top 20 devices:")
stage2_device_report = classification_report(
    y_dev_true_top,
    y_dev_pred_top,
    labels=top_dev_ids,              # ensure labels size matches target_names
    target_names=top_dev_names,
    digits=4,
    zero_division=0,
)
print(stage2_device_report)

with open(REPORTS_DIR / "multitask_stage2_test_reports.txt", "w") as f:
    f.write("Attack head report (TEST):\n")
    f.write(stage2_attack_report)
    f.write("\n\nDevice head report (TEST) - Top 20 devices:\n")
    f.write(stage2_device_report)

# ============================================================
# 10. Stage 3 (optional): joint fine-tuning
# ============================================================

if DO_JOINT_FINETUNE:
    print("\n" + "="*60)
    print("STAGE 3: Joint fine-tuning (attack-dominant loss)")
    print("="*60)

    # Unfreeze entire model
    for param in model.parameters():
        param.requires_grad = True

    optimizer_stage3 = torch.optim.AdamW(
        model.parameters(),
        lr=LEARNING_RATE_STAGE3,
        weight_decay=WEIGHT_DECAY,
    )
    # Fresh per-batch schedule for this stage: 10% warmup, then cosine decay.
    total_steps_stage3 = len(train_loader) * NUM_EPOCHS_STAGE3
    warmup_steps_stage3 = int(0.1 * total_steps_stage3)  # small warmup
    scheduler_stage3 = get_warmup_cosine_schedule(optimizer_stage3, warmup_steps_stage3, total_steps_stage3)

    # Early-stopping / checkpoint bookkeeping (selection metric: weighted
    # validation accuracy). The start value of -1.0 means the first Stage 3
    # epoch is always checkpointed; it is not compared with the Stage 2 model.
    best_combined_stage3 = -1.0
    epochs_no_improve3 = 0
    history_stage3 = []
    best_stage3_model_path = MODELS_DIR / "multitask_cnn_mlp_stage3_joint_finetuned.pt"

    def train_epoch_joint(model, loader, optimizer, scheduler, epoch):
        """Train for one epoch on the weighted sum of both task losses.

        The optimised loss is
        LAMBDA_ATTACK_JOINT * attack loss + LAMBDA_DEVICE_JOINT * device loss.
        Gradients of all parameters are clipped at MAX_GRAD_NORM, and the
        scheduler is stepped once per batch. A progress line is printed every
        200 batches; `epoch` is used only in that message.

        Returns:
            Mean joint loss per sample over the epoch.
        """
        # Puts the whole model, backbone included, in training mode.
        model.train()
        total_loss = 0.0
        total_samples = 0

        for batch_idx, (x, y_att, y_dev) in enumerate(loader):
            x = x.to(device)
            y_att = y_att.to(device)
            y_dev = y_dev.to(device)

            optimizer.zero_grad()
            logits_att, logits_dev = model(x)
            loss_att = criterion_attack(logits_att, y_att)
            loss_dev = criterion_device(logits_dev, y_dev)
            loss = LAMBDA_ATTACK_JOINT * loss_att + LAMBDA_DEVICE_JOINT * loss_dev
            loss.backward()
            torch.nn.utils.clip_grad_norm_(model.parameters(), MAX_GRAD_NORM)
            optimizer.step()
            scheduler.step()

            batch_size = x.size(0)
            total_loss += loss.item() * batch_size
            total_samples += batch_size

            if (batch_idx + 1) % 200 == 0:
                lr = scheduler.get_last_lr()[0]
                print(f"[Stage3] Epoch {epoch} | Batch {batch_idx+1}/{len(loader)} | "
                      f"Loss: {loss.item():.4f} (Att: {loss_att.item():.4f}, Dev: {loss_dev.item():.4f}) | LR: {lr:.2e}")

        return total_loss / total_samples

    for epoch in range(1, NUM_EPOCHS_STAGE3 + 1):
        print(f"\n[Stage3] Epoch {epoch}/{NUM_EPOCHS_STAGE3}")
        train_loss3 = train_epoch_joint(model, train_loader, optimizer_stage3, scheduler_stage3, epoch)
        val_loss_att3, val_loss_dev3, val_att_acc3, val_dev_acc3, _, _, _, _ = eval_both_heads(model, val_loader)

        # Model-selection score: weighted mix of the two validation accuracies.
        combined_metric = (
            JOINT_METRIC_ATTACK_WEIGHT * val_att_acc3
            + JOINT_METRIC_DEVICE_WEIGHT * val_dev_acc3
        )
        history_stage3.append({
            "epoch": epoch,
            "train_loss_joint": train_loss3,
            "val_loss_att": val_loss_att3,
            "val_loss_dev": val_loss_dev3,
            "val_att_acc": val_att_acc3,
            "val_dev_acc": val_dev_acc3,
            "combined_metric": combined_metric,
            "lr": scheduler_stage3.get_last_lr()[0],
        })

        print(f"  [Stage3] Train Loss: {train_loss3:.4f}")
        print(f"  [Stage3] Val Att Loss: {val_loss_att3:.4f}, Val Att Acc: {val_att_acc3:.4f}")
        print(f"  [Stage3] Val Dev Loss: {val_loss_dev3:.4f}, Val Dev Acc: {val_dev_acc3:.4f}")
        print(f"  [Stage3] Combined metric: {combined_metric:.4f}")

        if combined_metric > best_combined_stage3 + 1e-4:
            best_combined_stage3 = combined_metric
            epochs_no_improve3 = 0
            torch.save(model.state_dict(), best_stage3_model_path)
            print(f"  ✓ [Stage3] New best joint model saved (combined={combined_metric:.4f})")
        else:
            epochs_no_improve3 += 1
            print(f"  [Stage3] No improvement for {epochs_no_improve3} epoch(s)")

        if epochs_no_improve3 >= PATIENCE_STAGE3:
            print("\n[Stage3] Early stopping triggered.")
            break

    # Persist the per-epoch metrics.
    pd.DataFrame(history_stage3).to_csv(REPORTS_DIR / "multitask_stage3_joint_history.csv", index=False)

    # Load best Stage 3 model
    model.load_state_dict(torch.load(best_stage3_model_path, map_location=device))

    # Final TEST evaluation after joint fine-tune
    print("\n" + "="*60)
    print("STAGE 3: Final TEST evaluation (joint fine-tuned model)")
    print("="*60)
    test_loss_att3, test_loss_dev3, test_att_acc3, test_dev_acc3, y_att_true3, y_att_pred3, y_dev_true3, y_dev_pred3 = eval_both_heads(model, test_loader)
    print(f"[Stage3] Test Attack Loss: {test_loss_att3:.4f}, Test Attack Acc: {test_att_acc3:.4f}")
    print(f"[Stage3] Test Device Loss: {test_loss_dev3:.4f}, Test Device Acc: {test_dev_acc3:.4f}")

    # `labels` is passed explicitly (as the device report below already does)
    # so the report has one row per attack class. Without it, scikit-learn
    # raises when a class is absent from both the true and predicted labels.
    stage3_attack_report = classification_report(
        y_att_true3,
        y_att_pred3,
        labels=list(range(num_attacks)),
        target_names=attack_names,
        digits=4,
        zero_division=0,
    )
    print("\n[Stage3] Attack head classification report (TEST):")
    print(stage3_attack_report)

    # Device report restricted to the (up to) 20 most frequent true devices in
    # the test set; value_counts() is sorted by descending frequency.
    dev_counts3 = pd.Series(y_dev_true3).value_counts()
    top_k3 = min(20, len(dev_counts3))
    top_dev_ids3 = dev_counts3.index[:top_k3]
    mask_top3 = np.isin(y_dev_true3, top_dev_ids3)
    y_dev_true_top3 = y_dev_true3[mask_top3]
    y_dev_pred_top3 = y_dev_pred3[mask_top3]
    top_dev_names3 = [device_label_mapping[str(i)] for i in top_dev_ids3]

    stage3_device_report = classification_report(
        y_dev_true_top3,
        y_dev_pred_top3,
        labels=top_dev_ids3,           # ensure labels size matches target_names
        target_names=top_dev_names3,
        digits=4,
        zero_division=0,
    )
    print("\n[Stage3] Device head classification report (TEST) - Top 20 devices:")
    print(stage3_device_report)

    with open(REPORTS_DIR / "multitask_stage3_test_reports.txt", "w") as f:
        f.write("Attack head report (TEST):\n")
        f.write(stage3_attack_report)
        f.write("\n\nDevice head report (TEST) - Top 20 devices:\n")
        f.write(stage3_device_report)

# Final summary: printed whether or not Stage 3 ran, and points at the
# directories holding the checkpoints, histories and reports.
print("\n" + "="*60)
print("Multitask CNN+MLP staged training complete.")
print("="*60)
print("Models saved in:", MODELS_DIR)
print("Histories and reports in:", REPORTS_DIR)

PROJECT_ROOT: /Users/naeemulhassan/naeem-p/Cloud-Deployed-Multitask-IoT-IDS
Using processed data from: /Users/naeemulhassan/naeem-p/Cloud-Deployed-Multitask-IoT-IDS/data/processed
Using device: mps

Loading processed CSVs...
Train shape: (2126280, 139)
Val   shape: (455632, 139)
Test  shape: (455632, 139)
Number of attack classes: 8
Number of device classes: 94

Number of feature columns: 119
Example features: ['stream', 'src_port', 'dst_port', 'inter_arrival_time', 'time_since_previously_displayed_frame', 'port_class_dst', 'l4_tcp', 'l4_udp', 'ttl', 'eth_size', 'tcp_window_size', 'payload_entropy', 'handshake_cipher_suites_length', 'handshake_ciphersuites', 'handshake_extensions_length']

Cleaning NaN/Inf...
  [train] NaN before fill: 20652772, filling with 0.
  [val] NaN before fill: 4432636, filling with 0.
  [test] NaN before fill: 4426542, filling with 0.

Standardising features...
  [train] NaN after std: 0, Inf: 0
  [val] NaN after std: 0, Inf: 0
  [test] NaN after std: 0, Inf: 