# IDS — Tabular → CNN Training Example

This notebook demonstrates training the project's CNN on the tabular IDS dataset
(preprocessing → DataLoader → model → trainer → evaluation). Set `MODE` in the setup
cell to switch between binary and multiclass classification.

In [None]:
# Project imports and setup
import sys
from pathlib import Path

# Make the project root importable so the `src.*` imports below resolve.
# NOTE(review): assumes the kernel's working directory is one level below the
# project root (e.g. notebooks/) — confirm.
# The membership check keeps re-runs of this cell from appending duplicate
# entries to sys.path.
PROJECT_ROOT = Path('..').resolve()
if str(PROJECT_ROOT) not in sys.path:
    sys.path.append(str(PROJECT_ROOT))

import torch
import torch.nn as nn
from torch.utils.data import TensorDataset, DataLoader
import numpy as np
import pandas as pd

# Project modules
from src.data.dataset import load_raw_csv
from src.data.preprocess import preprocess_multiclass, preprocess_binary, preprocess_single_sample
from src.models.cnn_model import create_ids_model
from src.training.trainer import Trainer
from src.training.metrics import accuracy, confusion_matrix
from src.utils.helpers import set_seed, get_device, get_optimizer, get_scheduler
from src.utils.visualization import plot_training_history, plot_confusion_matrix


## 1) Setup

In [None]:
# Reproducibility and device
SEED = 42
set_seed(SEED)
device = get_device()
print('Device:', device)

# Quick config for the notebook (tweak as needed)
BATCH_SIZE = 64
EPOCHS = 30
LEARNING_RATE = 5e-4
MODE = 'multiclass'  # 'binary' or 'multiclass'
TEST_SIZE = 0.2

# Fail fast on a typo: later cells branch on `MODE == 'binary'` and treat any
# other value as multiclass, so an invalid MODE would otherwise be accepted silently.
if MODE not in ('binary', 'multiclass'):
    raise ValueError(f"MODE must be 'binary' or 'multiclass', got {MODE!r}")
TEST_SIZE = 0.2


## 2) Load & Preprocess Data

In [None]:
# Load raw CSVs from data/raw/ (concatenates files).
# NOTE(review): 'data/raw' is a relative path while the import cell treats '..'
# as the project root — confirm load_raw_csv resolves it as intended.
df = load_raw_csv(data_dir='data/raw')
print('Loaded dataframe with rows:', len(df))

# Choose preprocessing based on MODE
if MODE == 'binary':
    X_train, X_val, y_train, y_val = preprocess_binary(df, test_size=TEST_SIZE)
    num_classes = 2
else:
    X_train, X_val, y_train, y_val = preprocess_multiclass(df, test_size=TEST_SIZE)
    # Infer the class count from BOTH splits. Using y_train alone under-counts
    # when the highest label id only occurs in the validation split, which then
    # yields out-of-range targets for CrossEntropyLoss / confusion_matrix.
    # Assumes integer-encoded labels starting at 0 — TODO confirm in preprocess.
    observed_labels = np.concatenate([np.asarray(y_train).ravel(), np.asarray(y_val).ravel()])
    num_classes = int(observed_labels.max()) + 1 if np.unique(observed_labels).size > 1 else None

# The permute below assumes channels-last (N, H, W, C) arrays from the
# preprocessors; raise a readable error instead of a cryptic permute failure.
X_train_np, X_val_np = np.asarray(X_train), np.asarray(X_val)
if X_train_np.ndim != 4 or X_val_np.ndim != 4:
    raise ValueError(
        f'Expected 4-D (N, H, W, C) feature arrays, got shapes {X_train_np.shape} and {X_val_np.shape}'
    )

# Convert to torch tensors and reorder to (N, C, H, W)
X_train_t = torch.tensor(X_train_np, dtype=torch.float32).permute(0, 3, 1, 2)
X_val_t = torch.tensor(X_val_np, dtype=torch.float32).permute(0, 3, 1, 2)
# np.asarray lets pandas Series labels convert cleanly as well as ndarrays.
y_train_t = torch.tensor(np.asarray(y_train), dtype=torch.long)
y_val_t = torch.tensor(np.asarray(y_val), dtype=torch.long)

train_ds = TensorDataset(X_train_t, y_train_t)
val_ds = TensorDataset(X_val_t, y_val_t)

# num_workers=0: the data is already in-memory tensors, so worker processes only
# add pickling/IPC overhead (and multiprocessing workers are a frequent source
# of stalls when run from notebooks).
train_loader = DataLoader(train_ds, batch_size=BATCH_SIZE, shuffle=True, num_workers=0)
val_loader = DataLoader(val_ds, batch_size=BATCH_SIZE, shuffle=False, num_workers=0)

print('Train samples:', len(train_ds))
print('Val samples:', len(val_ds))
print('Num classes:', num_classes)


## 3) Create model

In [None]:
# Build the CNN through the project factory. Binary mode is fixed at two output
# classes; multiclass uses the count inferred during preprocessing.
if MODE == 'binary':
    model_kwargs = {'mode': 'binary', 'num_classes': 2}
elif num_classes is not None:
    model_kwargs = {'mode': 'multiclass', 'num_classes': num_classes}
else:
    raise ValueError('num_classes could not be inferred from labels')

model = create_ids_model(**model_kwargs).to(device)

n_trainable = sum(param.numel() for param in model.parameters() if param.requires_grad)
print(f'Model: {type(model).__name__}')
print(f'Trainable params: {n_trainable:,}')


## 4) Training setup

In [None]:
# Loss, optimiser and LR schedule handed to the project Trainer.
WEIGHT_DECAY = 1e-4
CHECKPOINT_DIR = 'models/checkpoints/notebook'

criterion = nn.CrossEntropyLoss()
optimizer = get_optimizer(
    model=model,
    optimizer_name='adamw',
    learning_rate=LEARNING_RATE,
    weight_decay=WEIGHT_DECAY,
)
scheduler = get_scheduler(optimizer=optimizer, scheduler_name='cosine', epochs=EPOCHS)

trainer = Trainer(
    model=model,
    criterion=criterion,
    optimizer=optimizer,
    scheduler=scheduler,
    device=device,
    checkpoint_dir=CHECKPOINT_DIR,
)


## 5) Train

In [None]:
# Run training; the returned `history` is what the next section plots.
# NOTE(review): patience semantics (presumably epochs without validation
# improvement) are defined by Trainer.train — confirm there.
EARLY_STOPPING_PATIENCE = 8

history = trainer.train(
    train_loader=train_loader,
    val_loader=val_loader,
    epochs=EPOCHS,
    early_stopping_patience=EARLY_STOPPING_PATIENCE,
    verbose=True,
)


## 6) Visualize training history

In [None]:
# Plot the history returned by trainer.train; the trailing ';' stops Jupyter
# from also echoing the function's return value.
plot_training_history(history);


## 7) Evaluate on validation set

In [None]:
# Collect argmax predictions over the whole validation set.
# NOTE(review): this scores the weights currently in `model`; if the Trainer
# keeps a separate best checkpoint, load it first — confirm.
model.eval()
pred_batches = []
target_batches = []
with torch.no_grad():
    for xb, yb in val_loader:
        # Only the inputs need the model's device; targets stay on the CPU,
        # avoiding a per-batch host -> device -> host round trip.
        logits = model(xb.to(device))
        pred_batches.append(logits.argmax(dim=1).cpu())
        target_batches.append(yb)

all_preds = torch.cat(pred_batches)
all_targets = torch.cat(target_batches)

# NOTE(review): the '%' suffix assumes `accuracy` returns a percentage — confirm.
acc = accuracy(all_preds, all_targets)
print(f'Validation Accuracy: {acc:.2f}%')

# Confusion matrix (numpy) and plot
cm = confusion_matrix(all_preds, all_targets, num_classes=num_classes)
plot_confusion_matrix(cm)


## 8) Save final model (optional)

In [None]:
# Save final model weights into models/final/, together with the MODE and
# num_classes needed to rebuild the architecture via create_ids_model before
# calling load_state_dict. The 'model_state_dict' key is unchanged.
FINAL_MODEL_PATH = Path('models/final') / 'final_model_notebook.pth'
FINAL_MODEL_PATH.parent.mkdir(parents=True, exist_ok=True)
torch.save(
    {
        'model_state_dict': model.state_dict(),
        'mode': MODE,
        'num_classes': num_classes,
    },
    FINAL_MODEL_PATH,
)
print(f'Saved {FINAL_MODEL_PATH.as_posix()}')
