# OBFUS-SIG-NIDS Integration (Obfuscation + SIG-Lite + Bit-Fingerprint)

Lưu ý: Chạy các ô này SAU khi đã có `model`, `train_loader`, `test_loader`, `args.DEVICE` trong notebook (sau huấn luyện hoặc sau khi load model).

- Tầng A: Obfuscation (permute + inverse) quanh lớp cuối, có `reseed()` runtime.
- Tầng B: SIG-Lite (D_KL(u||ŷ) + chuẩn ∂KL/∂W_last, ngưỡng median±k·MAD).
- Tầng C: Bit-Fingerprint (PSI histogram int8 + entropy bit-plane, cảnh báo drift).
- Controller: gộp alert (OR/AND), cooldown, hành động reseed.

Bạn có thể bật/tắt OBFUS-SIG, tinh chỉnh period, k, chuẩn gradient, ngưỡng PSI/entropy ngay trong notebook.


In [1]:
# Standalone OBFUS-SIG runtime (inline, no external imports)
import torch
import torch.nn as nn
import torch.nn.functional as F
import numpy as np
import time
from typing import Optional, Dict, Tuple, List, Literal

# --- Tầng A: Obfuscation ---
class ObfusAdapter(nn.Module):
    """Reorder one tensor dimension with a random permutation that can be re-drawn at runtime.

    ``forward(x)`` applies the permutation along ``dim``; ``forward(x, inverse=True)``
    applies its inverse, so the two calls compose to the identity.

    Args:
        size: Length of the permuted dimension.
        dim: Index of the dimension that is permuted.
        seed: Seed for a reproducible permutation. ``None`` draws a fresh,
            non-deterministic permutation.
    """

    def __init__(self, size: int, dim: int = 1, seed: Optional[int] = None) -> None:
        super().__init__()
        self.dim = dim
        self.size = int(size)
        perm = self._make_perm(seed)
        # Registered with persistent=False, so neither buffer is written to the state_dict.
        self.register_buffer("perm", perm, persistent=False)
        self.register_buffer("inv_perm", self._invert(perm), persistent=False)

    @staticmethod
    def _invert(perm: torch.Tensor) -> torch.Tensor:
        """Return the inverse permutation, i.e. ``inv[perm[i]] == i``."""
        inv = torch.empty_like(perm)
        inv[perm] = torch.arange(perm.numel(), device=perm.device)
        return inv

    def _make_perm(self, seed: Optional[int]) -> torch.Tensor:
        """Draw a permutation of ``range(size)`` on the CPU.

        A freshly constructed ``torch.Generator`` always starts from the same
        default seed, so without an explicit ``seed()`` call every unseeded
        draw would return the identical permutation.
        """
        gen = torch.Generator()
        if seed is None:
            gen.seed()  # non-deterministic seed
        else:
            gen.manual_seed(int(seed))
        return torch.randperm(self.size, generator=gen)

    @torch.no_grad()
    def reseed(self, seed: Optional[int] = None) -> None:
        """Replace the permutation in place (a new random one when ``seed`` is None)."""
        perm = self._make_perm(seed).to(self.perm.device)
        self.perm.copy_(perm)
        self.inv_perm.copy_(self._invert(self.perm))

    def forward(self, x: torch.Tensor, inverse: bool = False) -> torch.Tensor:
        """Return ``x`` reordered along ``self.dim`` by the permutation or its inverse."""
        index = self.inv_perm if inverse else self.perm
        return torch.index_select(x, self.dim, index)

class ObfusPair(nn.Module):
    """Wrap ``child`` with a permute / inverse-permute pair on its input.

    The wrapped module must compute exactly what ``child`` computes. The
    earlier layout (permute the input, run ``child``, inverse-permute the
    output) did not: the child's weights were never permuted, and the inverse
    index is sized for the input dimension, so applying it to an output of a
    different width either fails with an index error or returns a tensor of
    the wrong shape. Both halves of the pair are therefore applied before
    ``child`` runs, so the child sees features in their original order.

    NOTE(review): if the goal is to obfuscate the stored weight layout, the
    child's weight columns would also have to be permuted on every reseed;
    that is not done here -- confirm the intended design.

    Args:
        child: Module to wrap.
        size: Length of the input dimension that is permuted.
        dim: Index of the permuted input dimension.
        seed: Optional seed forwarded to the adapter.
    """

    def __init__(self, child: nn.Module, size: int, dim: int = 1, seed: Optional[int] = None) -> None:
        super().__init__()
        self.child = child
        self.adapter = ObfusAdapter(size=size, dim=dim, seed=seed)
        self.dim = dim

    @torch.no_grad()
    def reseed(self, seed: Optional[int] = None) -> None:
        """Re-draw the adapter's permutation; the wrapped function is unchanged."""
        self.adapter.reseed(seed)

    def forward(self, x: torch.Tensor) -> torch.Tensor:
        """Return ``child(x)`` after routing ``x`` through the permute/inverse pair."""
        shuffled = self.adapter(x, inverse=False)
        restored = self.adapter(shuffled, inverse=True)
        return self.child(restored)

# --- Tầng B: SIG-Lite ---
def _median_and_mad(values: torch.Tensor) -> Tuple[float, float]:
    """Return ``(median, MAD)`` of ``values`` as Python floats.

    MAD is the median of the absolute deviations from the median.
    """
    center = values.median()
    spread = (values - center).abs().median()
    return center.item(), spread.item()

def _find_last_linear(model: nn.Module) -> Optional[nn.Linear]:
    """Return the last ``nn.Linear`` yielded by ``model.modules()``, or None if there is none."""
    linear_layers = [layer for layer in model.modules() if isinstance(layer, nn.Linear)]
    return linear_layers[-1] if linear_layers else None

class SigLiteMonitor:
    """Periodic probe that watches two scalar signatures of a model.

    On a probe batch it computes (1) KL(u || y_hat) between the uniform
    distribution and the softmax output and (2) a norm of the gradient of that
    KL with respect to the last ``nn.Linear`` weight. ``fit_baseline`` records
    median and MAD of both; ``step`` raises an alert when either value leaves
    ``median +/- k * MAD``.

    Args:
        model: Model to monitor; must contain at least one ``nn.Linear``.
        probe_loader: Iterable of batches (``(x, y)`` pairs or bare ``x``).
        period: ``step`` only probes when ``batch_idx % period == 0``.
        k: Width of the acceptance band in MADs.
        device: Device probe batches are moved to (defaults to the model's).
        grad_norm_type: ``"l2"`` for the Euclidean norm, anything else for L1.
        normalize_by_params: Divide the norm by the number of weight entries.

    Raises:
        ValueError: If the model has no ``nn.Linear`` layer.
    """

    def __init__(
        self,
        model: nn.Module,
        probe_loader: torch.utils.data.DataLoader,
        period: int = 500,
        k: float = 3.0,
        device: Optional[torch.device] = None,
        grad_norm_type: Literal["l1","l2"] = "l1",
        normalize_by_params: bool = True,
    ) -> None:
        self.model = model
        self.probe_loader = probe_loader
        self.period = int(period)
        self.k = float(k)
        self.device = device or next(model.parameters()).device
        self.last_layer = _find_last_linear(model)
        if self.last_layer is None:
            raise ValueError("SigLiteMonitor requires a model with a final nn.Linear layer.")
        self._probe_iter = None
        self.kl_med = self.kl_mad = self.gn_med = self.gn_mad = None
        self.grad_norm_type = grad_norm_type
        self.normalize_by_params = normalize_by_params

    def _next_probe_batch(self):
        """Return the next probe batch on ``self.device``, restarting the loader when exhausted."""
        if self._probe_iter is None:
            self._probe_iter = iter(self.probe_loader)
        try:
            batch = next(self._probe_iter)
        except StopIteration:
            self._probe_iter = iter(self.probe_loader)
            try:
                batch = next(self._probe_iter)
            except StopIteration:
                raise RuntimeError("SigLiteMonitor probe_loader yielded no batches.") from None
        if isinstance(batch, (list, tuple)) and len(batch) >= 2:
            x, y = batch[0], batch[1]
        else:
            x, y = batch, None
        return x.to(self.device), (y.to(self.device) if y is not None else None)

    def _probs(self, logits: torch.Tensor) -> torch.Tensor:
        """Softmax over the last dimension.

        Must stay differentiable: ``_grad_norm`` back-propagates the KL through
        these probabilities, which fails if they are computed under ``no_grad``.
        """
        return F.softmax(logits, dim=-1)

    def _kl_uniform(self, probs: torch.Tensor) -> torch.Tensor:
        """Mean over the batch of KL(u || probs), with u uniform over the classes."""
        c = probs.size(-1)
        u = 1.0 / float(c)
        eps = 1e-8
        # float() keeps the constant a plain Python scalar in the tensor expression.
        log_u = float(np.log(u + eps))
        log_term = log_u - torch.log(probs + eps)
        return (u * log_term).sum(dim=-1).mean()

    def _grad_norm(self, kl: torch.Tensor) -> torch.Tensor:
        """Norm of d(kl)/d(last_layer.weight), optionally divided by the entry count.

        ``torch.autograd.grad`` returns the gradient without writing to any
        ``.grad`` attribute, so existing parameter gradients are left alone.
        """
        grad_w = torch.autograd.grad(kl, self.last_layer.weight, retain_graph=False, allow_unused=False)[0]
        if self.grad_norm_type == "l2":
            norm = torch.linalg.vector_norm(grad_w, ord=2)
        else:
            norm = grad_w.abs().sum()
        if self.normalize_by_params and grad_w.numel() > 0:
            norm = norm / float(grad_w.numel())
        return norm

    def fit_baseline(self, steps: int = 50) -> Dict[str, float]:
        """Probe ``steps`` batches and store median/MAD of KL and gradient norm.

        Puts the model in eval mode. Returns the four fitted statistics.

        Raises:
            ValueError: If ``steps`` is not positive.
        """
        if steps <= 0:
            raise ValueError("fit_baseline() requires steps >= 1.")
        self.model.eval()
        kl_vals, gn_vals = [], []
        # enable_grad: callers may invoke this from inside a no_grad region.
        with torch.enable_grad():
            for _ in range(steps):
                x, _ = self._next_probe_batch()
                logits = self.model(x)
                probs = self._probs(logits)
                kl = self._kl_uniform(probs)
                gn = self._grad_norm(kl)
                kl_vals.append(kl.detach())
                gn_vals.append(gn.detach())
        kl_t = torch.stack(kl_vals)
        gn_t = torch.stack(gn_vals)
        self.kl_med, self.kl_mad = _median_and_mad(kl_t)
        self.gn_med, self.gn_mad = _median_and_mad(gn_t)
        return {"kl_median": self.kl_med, "kl_mad": self.kl_mad, "grad_norm_median": self.gn_med, "grad_norm_mad": self.gn_mad}

    def _thresholds(self):
        """Return ``((kl_low, kl_high), (gn_low, gn_high))``; requires a fitted baseline."""
        if any(v is None for v in [self.kl_med, self.kl_mad, self.gn_med, self.gn_mad]):
            raise RuntimeError("Call fit_baseline() before using SigLiteMonitor.")
        kl_L = self.kl_med - self.k * self.kl_mad
        kl_U = self.kl_med + self.k * self.kl_mad
        gn_L = self.gn_med - self.k * self.gn_mad
        gn_U = self.gn_med + self.k * self.gn_mad
        return (kl_L, kl_U), (gn_L, gn_U)

    def step(self, batch_idx: int) -> Dict[str, float]:
        """Run one probe when ``batch_idx`` is a multiple of ``period``.

        Returns ``{"ran": 0, "alert": 0}`` on skipped steps. Otherwise returns
        the measured values, the thresholds and the alert flag. The gradient
        norm is reported under ``"grad_norm"`` and, for backward compatibility,
        under ``"grad_norm_l1"`` whichever norm type is configured.

        Raises:
            RuntimeError: If ``fit_baseline`` has not been called.
        """
        if (batch_idx % self.period) != 0:
            return {"ran": 0, "alert": 0}
        # Check calibration first so an unfitted monitor fails before the probe pass.
        (kl_L, kl_U), (gn_L, gn_U) = self._thresholds()
        self.model.eval()
        with torch.enable_grad():
            x, _ = self._next_probe_batch()
            logits = self.model(x)
            probs = self._probs(logits)
            kl = self._kl_uniform(probs)
            gn = self._grad_norm(kl)
        kl_val = float(kl.detach().cpu().item())
        gn_val = float(gn.detach().cpu().item())
        kl_alert = int(kl_val < kl_L or kl_val > kl_U)
        gn_alert = int(gn_val < gn_L or gn_val > gn_U)
        alert = int(kl_alert or gn_alert)
        return {
            "ran": 1,
            "alert": alert,
            "kl": kl_val,
            "grad_norm": gn_val,
            "grad_norm_l1": gn_val,
            "kl_L": kl_L,
            "kl_U": kl_U,
            "gn_L": gn_L,
            "gn_U": gn_U,
        }

# --- Tầng C: Bit-Fingerprint ---
class BitFingerprint:
    """Detect weight drift from per-parameter int8 histograms and bit-plane entropies.

    ``build_baseline`` snapshots, for every named parameter, a 256-bin
    histogram of its int8-quantized values and the entropy of each of the 8
    bit planes. ``update`` recomputes both and flags a parameter when the PSI
    between histograms exceeds ``threshold_psi`` or the largest per-plane
    entropy change exceeds ``threshold_entropy``.
    """

    def __init__(self, model: nn.Module, threshold_psi: float = 0.1, threshold_entropy: float = 0.15) -> None:
        self.model = model
        self.threshold_psi = float(threshold_psi)
        self.threshold_entropy = float(threshold_entropy)
        self.baseline_hists: Dict[str, np.ndarray] = {}
        self.baseline_entropy: Dict[str, np.ndarray] = {}

    @staticmethod
    def _quantize_int8(t: torch.Tensor) -> np.ndarray:
        """Quantize ``t`` to int8 with a symmetric scale of ``max|t| / 127``.

        An empty tensor yields an empty int8 array instead of failing in
        ``np.max``; an all-zero tensor uses a scale of 1.
        """
        arr = t.detach().float().cpu().numpy()
        if arr.size == 0:
            return np.zeros(arr.shape, dtype=np.int8)
        scale = np.max(np.abs(arr))
        if scale <= 1e-12:
            scale = 1.0
        q = np.clip(np.round(arr / (scale / 127.0)), -128, 127).astype(np.int8)
        return q

    @staticmethod
    def _hist256(q: np.ndarray) -> np.ndarray:
        """Count occurrences of each int8 value (256 unit-width bins over [-128, 128])."""
        counts, _ = np.histogram(q.astype(np.int16), bins=256, range=(-128, 128))
        return counts.astype(np.float64)

    @staticmethod
    def _psi(p: np.ndarray, q: np.ndarray, eps: float = 1e-8) -> float:
        """Population Stability Index between two count vectors (normalized internally)."""
        p = p.astype(np.float64)
        q = q.astype(np.float64)
        p = p / (p.sum() + eps)
        q = q / (q.sum() + eps)
        return float(np.sum((p - q) * np.log((p + eps) / (q + eps))))

    @staticmethod
    def _bit_plane_entropy(q: np.ndarray) -> np.ndarray:
        """Binary entropy (in bits) of each of the 8 bit planes of an int8 array.

        Returns zeros for an empty array (no bits, no uncertainty), which also
        avoids taking the mean of an empty slice.
        """
        if q.size == 0:
            return np.zeros(8, dtype=np.float64)
        q_u = q.view(np.uint8)  # reinterpret the same bytes as unsigned
        eps = 1e-12
        ent = []
        for b in range(8):
            bits = (q_u >> b) & 1
            p1 = bits.mean()
            p0 = 1.0 - p1
            if 0.0 < p1 < 1.0:
                h = -(p1 * np.log2(p1 + eps) + p0 * np.log2(p0 + eps))
            else:
                h = 0.0
            ent.append(h)
        return np.array(ent, dtype=np.float64)

    def build_baseline(self) -> Dict[str, float]:
        """Snapshot histogram and bit-plane entropy of every named parameter.

        Returns ``{"num_params": n}`` where ``n`` is the number of parameter
        tensors recorded (not the number of scalar weights).
        """
        self.baseline_hists.clear()
        self.baseline_entropy.clear()
        num_params = 0
        for name, p in self.model.named_parameters():
            q = self._quantize_int8(p.data)
            self.baseline_hists[name] = self._hist256(q)
            self.baseline_entropy[name] = self._bit_plane_entropy(q)
            num_params += 1
        return {"num_params": float(num_params)}

    def update(self) -> Dict[str, object]:
        """Compare current parameters against the baseline.

        Parameters that were not present at baseline time are ignored.

        Raises:
            RuntimeError: If ``build_baseline`` has not produced any entry.
        """
        if not self.baseline_hists:
            raise RuntimeError("Call build_baseline() before update().")
        psi_per_layer: Dict[str, float] = {}
        entropy_drift_per_layer: Dict[str, float] = {}
        alert_layers: List[str] = []
        entropy_alert_layers: List[str] = []
        for name, p in self.model.named_parameters():
            if name not in self.baseline_hists:
                continue
            q = self._quantize_int8(p.data)
            psi_val = self._psi(self.baseline_hists[name], self._hist256(q))
            psi_per_layer[name] = psi_val
            if psi_val > self.threshold_psi:
                alert_layers.append(name)
            cur_ent = self._bit_plane_entropy(q)
            base_ent = self.baseline_entropy.get(name, cur_ent)
            drift = float(np.max(np.abs(cur_ent - base_ent)))
            entropy_drift_per_layer[name] = drift
            if drift > self.threshold_entropy:
                entropy_alert_layers.append(name)
        max_psi = max(psi_per_layer.values()) if psi_per_layer else 0.0
        max_entropy_drift = max(entropy_drift_per_layer.values()) if entropy_drift_per_layer else 0.0
        any_alert = int(bool(alert_layers) or bool(entropy_alert_layers))
        return {
            "alert": any_alert,
            "max_psi": float(max_psi),
            "max_entropy_drift": float(max_entropy_drift),
            "psi_per_layer": psi_per_layer,
            "alert_layers": alert_layers,
            "entropy_drift_per_layer": entropy_drift_per_layer,
            "entropy_alert_layers": entropy_alert_layers,
        }

# --- Controller ---
class ControllerPolicy:
    """Combine monitor alerts and reseed the registered adapters, with a cooldown.

    Args:
        alert_mode: ``"or"`` acts when any alert is set, ``"and"`` only when
            all alerts are set.
        cooldown_steps: Minimum number of ``step`` calls between two actions.
        max_logs: Keep at most this many most-recent log entries so a long
            monitoring run does not grow memory without bound; ``None`` keeps
            everything.

    Raises:
        ValueError: If ``alert_mode`` is not ``"or"`` or ``"and"``.
    """

    def __init__(self, alert_mode: str = "or", cooldown_steps: int = 1000, max_logs: Optional[int] = 10000) -> None:
        # Explicit check instead of ``assert``: asserts are stripped under ``python -O``.
        if alert_mode not in ("or", "and"):
            raise ValueError(f"alert_mode must be 'or' or 'and', got {alert_mode!r}")
        self.alert_mode = alert_mode
        self.cooldown_steps = int(cooldown_steps)
        self.max_logs = None if max_logs is None else max(0, int(max_logs))
        self._last_action_step = -10**9  # far in the past: the first alert is never in cooldown
        self._step = 0
        self._adapters: List["ObfusPair"] = []
        self._logs: List[Dict[str, object]] = []

    def register_adapters(self, adapters: List["ObfusPair"]) -> None:
        """Set the adapters whose ``reseed()`` is called when the policy acts."""
        self._adapters = adapters

    def _should_act(self, alerts: Dict[str, int]) -> bool:
        """Return True when the alert combination fires and the cooldown has elapsed."""
        vals = list(alerts.values())
        if not vals:
            return False
        combine = any if self.alert_mode == "or" else all
        fire = combine(v > 0 for v in vals)
        return fire and (self._step - self._last_action_step >= self.cooldown_steps)

    def step(self, metrics: Dict[str, object]) -> Dict[str, object]:
        """Consume one round of metrics and possibly reseed the adapters.

        Reads ``metrics["sig_alert"]`` and ``metrics["fp_alert"]`` (default 0),
        logs the round, and returns the action taken with the step index.
        """
        alerts = {"sig": int(metrics.get("sig_alert", 0)), "fp": int(metrics.get("fp_alert", 0))}
        action_taken = "none"
        if self._should_act(alerts):
            for a in self._adapters:
                a.reseed()
            action_taken = "reseed_adapters"
            self._last_action_step = self._step
        self._logs.append({"t": time.time(), "step": self._step, "alerts": alerts, "action": action_taken, "extras": metrics})
        if self.max_logs is not None and len(self._logs) > self.max_logs:
            # Drop the oldest entries; the list object itself is kept.
            del self._logs[: len(self._logs) - self.max_logs]
        self._step += 1
        return {"action": action_taken, "step": self._step - 1}

# --- Runtime ---
class ObfusSigRuntime:
    """Wire the obfuscation adapter, both monitors and the controller around a model.

    Construction wraps the last ``nn.Linear`` of ``model`` in an ``ObfusPair``
    (in place, on the model itself), then builds a ``SigLiteMonitor``, a
    ``BitFingerprint`` and a ``ControllerPolicy`` that reseeds the adapters.

    Args:
        model: Model to protect; modified in place.
        probe_loader: Batches used by the SIG-Lite probe.
        alert_mode: ``"or"`` / ``"and"`` combination of the two alerts.
        sig_period: Probe period of the SIG-Lite monitor; also the controller
            cooldown (with a minimum of 2).
        sig_k: Threshold width of the SIG-Lite monitor, in MADs.
        fp_threshold: PSI threshold of the fingerprint.
        fp_entropy_threshold: Bit-plane entropy drift threshold.
        grad_norm_type: Gradient norm used by SIG-Lite.
        normalize_grad: Whether SIG-Lite divides the norm by the weight count.
        make_shadow: Accepted for interface compatibility; not used by this class.
        device: Device for probe batches (defaults to the model's device).
    """

    def __init__(
        self,
        model: nn.Module,
        probe_loader: torch.utils.data.DataLoader,
        alert_mode: str = "or",
        sig_period: int = 500,
        sig_k: float = 3.0,
        fp_threshold: float = 0.1,
        fp_entropy_threshold: float = 0.15,
        grad_norm_type: str = "l1",
        normalize_grad: bool = True,
        make_shadow: bool = False,
        device: Optional[torch.device] = None,
    ) -> None:
        self.device = device or next(model.parameters()).device
        self.model = model
        self._wrap_last_linear()
        # Collect adapters
        self.adapters = [m for m in self.model.modules() if isinstance(m, ObfusPair)]
        # Monitors
        self.sig = SigLiteMonitor(self.model, probe_loader, period=sig_period, k=sig_k, device=self.device, grad_norm_type=grad_norm_type, normalize_by_params=normalize_grad)
        self.fp = BitFingerprint(self.model, threshold_psi=fp_threshold, threshold_entropy=fp_entropy_threshold)
        self.ctrl = ControllerPolicy(alert_mode=alert_mode, cooldown_steps=max(2, sig_period))
        self.ctrl.register_adapters(self.adapters)

    def _wrap_last_linear(self) -> None:
        """Replace the model's last ``nn.Linear`` with an ``ObfusPair`` around it.

        Does nothing when there is no Linear layer, when the model itself is
        the Linear (a root module cannot be replaced in place), or when that
        layer is already the child of an ``ObfusPair`` -- e.g. the notebook
        cell was re-run on the same model, which would otherwise nest wrappers.
        """
        last_name = None
        last_linear = None
        for name, module in self.model.named_modules():
            if isinstance(module, nn.Linear):
                last_name, last_linear = name, module
        if last_linear is None or not last_name:
            return
        *parent_path, leaf = last_name.split(".")
        parent = self.model
        for part in parent_path:
            parent = getattr(parent, part)
        if isinstance(parent, ObfusPair):
            return
        wrapped = ObfusPair(child=last_linear, size=last_linear.in_features, dim=1, seed=None)
        setattr(parent, leaf, wrapped)

    def calibrate(self, sig_steps: int = 50) -> Dict[str, float]:
        """Build the fingerprint baseline and fit the SIG-Lite baseline; return both stat dicts merged."""
        fp_stats = self.fp.build_baseline()
        sig_stats = self.sig.fit_baseline(steps=sig_steps)
        return {**fp_stats, **sig_stats}

    def periodic_check(self, batch_idx: int) -> Dict[str, object]:
        """Run both monitors and feed their alerts to the controller.

        The fingerprint check runs on every call; the SIG-Lite probe only on
        the steps selected by its period. Returns the three result dicts under
        ``"sig"``, ``"fp"`` and ``"ctrl"``.
        """
        sig_ret = self.sig.step(batch_idx)
        fp_ret = self.fp.update()
        ctrl_ret = self.ctrl.step({
            "sig_alert": int(sig_ret.get("alert", 0)),
            "fp_alert": int(fp_ret.get("alert", 0)),
            "sig": sig_ret,
            "fp": {
                "max_psi": fp_ret.get("max_psi", 0.0),
                "max_entropy_drift": fp_ret.get("max_entropy_drift", 0.0),
            },
        })
        return {"sig": sig_ret, "fp": fp_ret, "ctrl": ctrl_ret}

# Resolve the compute device without requiring the notebook's `args` object
try:
    device = args.DEVICE
except (NameError, AttributeError):
    # `args` is not defined yet (or has no DEVICE attribute): auto-detect instead.
    device = 'cuda' if torch.cuda.is_available() else 'cpu'
print('Device:', device)


Device: cuda


In [None]:
# Build ObfusSigRuntime, calibrate on clean, and evaluate with periodic checks
from torch.utils.data import DataLoader

# Fail fast when the notebook state this cell relies on is missing.
assert 'model' in globals(), 'Cần có biến model (đã huấn luyện hoặc đã load).'
assert 'train_loader' in globals() and 'test_loader' in globals(), 'Cần train_loader và test_loader.'

# Default configuration; every entry can be tuned.
OBFUS_CFG = {
    'alert_mode': 'or',  # 'or' catches early, 'and' lowers the false-alarm rate (FAR)
    'sig_period': 500,
    'sig_k': 3.0,
    'fp_threshold': 0.10,
    'fp_entropy_threshold': 0.15,
    'grad_norm_type': 'l1',
    'normalize_grad': True,
    'make_shadow': False,
}

# Build the runtime around the model, probing with training batches.
runtime = ObfusSigRuntime(model=model, probe_loader=train_loader, device=device, **OBFUS_CFG)

# Calibrate both monitors on the current model state.
cal_stats = runtime.calibrate(sig_steps=50)
print('Calibrated:', cal_stats)

@torch.no_grad()
def evaluate_with_runtime(model, loader, runtime):
    """Return accuracy (percent) over ``loader`` while running the runtime check before each batch.

    Batches are moved to ``runtime.device`` when the runtime exposes one,
    otherwise to the device of the model's first parameter (CPU if the model
    has none), so the function does not depend on a notebook-level global.
    An empty loader yields ``0.0`` instead of a division by zero.
    """
    model.eval()
    target = getattr(runtime, 'device', None)
    if target is None:
        first_param = next(model.parameters(), None)
        target = first_param.device if first_param is not None else 'cpu'
    total, correct = 0, 0
    for step, (x, y) in enumerate(loader):
        x, y = x.to(target), y.to(target)
        # Periodic check (SIG-Lite + fingerprint + controller)
        runtime.periodic_check(step)
        logits = model(x)
        pred = logits.argmax(1)
        correct += (pred == y).sum().item()
        total += y.size(0)
    if total == 0:
        return 0.0
    return correct / total * 100

# Evaluate on the test loader with the runtime checks active and report the accuracy in percent.
acc_rt = evaluate_with_runtime(model, test_loader, runtime)
print(f'Accuracy (with runtime probe): {acc_rt:.2f}%')


In [None]:
# Bit-flip attack helpers (PBS/RandomFlip) for quantized models
import csv
import numpy as np

# Tái sử dụng các lớp lượng tử hoá đã định nghĩa trong notebook: quan_Conv1d, quan_Linear, CustomBlock

def _is_quant_module(module):
    """Return True when ``module`` is one of the notebook's quantized layer types.

    ``quan_Conv1d``, ``quan_Linear`` and ``CustomBlock`` must already be
    defined in the notebook namespace when this is called.
    """
    return isinstance(module, (quan_Conv1d, quan_Linear, CustomBlock))

@torch.no_grad()
def _flip_one_bit_in_module_weight(module, element_index=None, bit_index=None):
    """Flip one bit of one quantized weight of ``module``, in place.

    The chosen weight is converted to its integer code (``round(w / step_size)``),
    the bit is toggled in the ``N_bits``-wide two's-complement representation,
    and the result is written back as ``code * step_size``. Only the selected
    element is modified; all other weights keep their exact stored values, so
    a caller can undo the flip by restoring that single element.

    Args:
        module: Object with ``weight``, ``N_bits`` and ``step_size`` attributes.
        element_index: Flat index of the weight to attack (random when None).
        bit_index: Bit to toggle, 0 = least significant (random when None).

    Returns:
        ``(old_val, new_val, element_index, bit_index)``, or ``None`` when the
        module lacks one of the required attributes or has no weights.

    Raises:
        ValueError: If ``bit_index`` is outside ``[0, N_bits)``.
    """
    if not hasattr(module, 'weight') or not hasattr(module, 'N_bits') or not hasattr(module, 'step_size'):
        return None
    weight = module.weight.data
    n_bits = int(getattr(module, 'N_bits', 8))
    step_size_tensor = getattr(module, 'step_size')
    if isinstance(step_size_tensor, torch.Tensor):
        # Only the first entry is used -- assumes a single per-tensor step size.
        step_size = step_size_tensor.detach().float().view(-1)[0].item()
    else:
        step_size = float(step_size_tensor)
    if step_size == 0.0:
        # Step size not initialised: derive one from the weight range.
        max_abs = weight.abs().max().item() if weight.numel() > 0 else 1.0
        step_size = max(1e-6, max_abs / max(1.0, (2 ** n_bits - 2) / 2.0))

    flat = weight.view(-1)
    numel = flat.numel()
    if numel == 0:
        return None
    if element_index is None:
        element_index = int(torch.randint(low=0, high=numel, size=(1,)).item())
    if bit_index is None:
        bit_index = int(torch.randint(low=0, high=n_bits, size=(1,)).item())
    if not 0 <= bit_index < n_bits:
        raise ValueError(f"bit_index must be in [0, {n_bits}), got {bit_index}")

    old_val = flat[element_index].item()
    code = int(torch.round(flat[element_index] / step_size).to(torch.int32).item())
    # Keep the code inside the signed n-bit range so the two's-complement view is well defined.
    sign_bit = 1 << (n_bits - 1)
    code = max(-sign_bit, min(sign_bit - 1, code))
    unsigned = code & ((1 << n_bits) - 1)
    unsigned ^= 1 << bit_index
    signed = unsigned - (1 << n_bits) if unsigned & sign_bit else unsigned
    flat[element_index] = torch.tensor(signed, dtype=flat.dtype, device=flat.device) * step_size
    new_val = flat[element_index].item()
    return (old_val, new_val, element_index, bit_index)

@torch.no_grad()
def _get_quant_modules(model):
    """Return ``(name, module)`` pairs for every quantized submodule of ``model``, in traversal order."""
    return [
        (qualified_name, submodule)
        for qualified_name, submodule in model.named_modules()
        if _is_quant_module(submodule)
    ]

@torch.no_grad()
def _random_flip_one_bit(model):
    """Flip one random bit in one randomly chosen quantized module.

    Returns ``None`` when the model has no quantized module, otherwise a dict
    with the module name and the flip result.
    """
    candidates = _get_quant_modules(model)
    if len(candidates) == 0:
        return None
    pick = torch.randint(low=0, high=len(candidates), size=(1,)).item()
    target_name, target_module = candidates[pick]
    return {'module': target_name, 'result': _flip_one_bit_in_module_weight(target_module)}

@torch.no_grad()
def _compute_batch_loss(model, criterion, batch_x, batch_y):
    """Return ``criterion(model(batch_x), batch_y)`` as a Python float, without tracking gradients."""
    return float(criterion(model(batch_x), batch_y).item())

@torch.no_grad()
def _progressive_bit_search(model, criterion, calib_x, calib_y, max_trials=16):
    """Try random single-bit flips and commit the one that raises the calibration loss most.

    Each trial flips one random bit, measures the loss on the calibration
    batch, then restores the module's full weight tensor from a snapshot. A
    full restore, rather than writing back one element, guarantees the model
    is exactly unchanged between trials whatever the flip helper touched.
    If no trial increases the loss, a random flip is applied instead.

    Returns:
        ``None`` when the model has no quantized modules; otherwise a dict
        describing the committed flip (the random-flip dict on fallback).
    """
    modules = _get_quant_modules(model)
    if not modules:
        return None
    # Only modules that actually carry a non-empty weight tensor can be attacked.
    candidates = [
        (n, m) for n, m in modules
        if getattr(m, 'weight', None) is not None and m.weight.data.numel() > 0
    ]
    if not candidates:
        return _random_flip_one_bit(model)
    base_loss = _compute_batch_loss(model, criterion, calib_x, calib_y)
    best_delta = 0.0
    best_choice = None
    # NOTE(review): as in the original, the trial budget is capped by the number of
    # attackable modules, not only by max_trials -- confirm this is intended.
    trials = min(max_trials, len(candidates))
    for _ in range(trials):
        name, module = candidates[torch.randint(low=0, high=len(candidates), size=(1,)).item()]
        weight = module.weight.data
        elem_idx = int(torch.randint(low=0, high=weight.numel(), size=(1,)).item())
        bit_idx = int(torch.randint(low=0, high=int(getattr(module, 'N_bits', 8)), size=(1,)).item())
        snapshot = weight.clone()
        try:
            _flip_one_bit_in_module_weight(module, elem_idx, bit_idx)
            trial_loss = _compute_batch_loss(model, criterion, calib_x, calib_y)
        finally:
            # Revert the trial completely, even if the loss evaluation failed.
            weight.copy_(snapshot)
        delta = trial_loss - base_loss
        if delta > best_delta:
            best_delta = delta
            best_choice = (name, module, elem_idx, bit_idx)
    if best_choice is not None:
        name, module, elem_idx, bit_idx = best_choice
        flip_result = _flip_one_bit_in_module_weight(module, elem_idx, bit_idx)
        return {'module': name, 'elem_idx': elem_idx, 'bit_idx': bit_idx, 'delta_loss': best_delta, 'result': flip_result}
    return _random_flip_one_bit(model)


In [None]:
# Attack loop with per-iteration logging (integrated with OBFUS-SIG runtime)
import os
from tqdm import tqdm

# Choose the attack mode and the number of iterations
ATTACK_MODE = 'pbs'  # 'pbs' | 'random_flip' | 'pbs_to_random' | 'random_to_pbs'
ATTACK_ITERS = 25

# Prepare the calibration batch used by PBS (first batch of train_loader)
criterion = torch.nn.CrossEntropyLoss()
calib_batch = next(iter(train_loader))
calib_x, calib_y = calib_batch[0].to(device), calib_batch[1].to(device)

# One row per attack iteration; written to CSV after the loop
iter_logs = []
os.makedirs('results/defense_results', exist_ok=True)

for i in range(int(ATTACK_ITERS)):
    # Apply the flip(s) for this iteration; in the two-stage modes only the
    # second flip is kept in `info` and therefore logged.
    if ATTACK_MODE == 'pbs':
        info = _progressive_bit_search(model, criterion, calib_x, calib_y, max_trials=16)
    elif ATTACK_MODE == 'random_flip':
        info = _random_flip_one_bit(model)
    elif ATTACK_MODE == 'pbs_to_random':
        _ = _progressive_bit_search(model, criterion, calib_x, calib_y, max_trials=16)
        info = _random_flip_one_bit(model)
    elif ATTACK_MODE == 'random_to_pbs':
        _ = _random_flip_one_bit(model)
        info = _progressive_bit_search(model, criterion, calib_x, calib_y, max_trials=16)
    else:
        # Unknown mode: fall back to a random flip
        info = _random_flip_one_bit(model)

    # Run the OBFUS-SIG check for the current step.
    # NOTE(review): the SIG-Lite probe only runs when the step index is a multiple of
    # its period; with sig_period=500 and steps 1..25 it looks like it never runs in
    # this loop, leaving the sig columns empty -- confirm the intended period.
    ret = runtime.periodic_check(i + 1)

    # Measure test accuracy after this attack iteration
    acc_i = 0.0
    model.eval()
    total_i, correct_i = 0, 0
    with torch.no_grad():
        for x, y in test_loader:
            x, y = x.to(device), y.to(device)
            logits = model(x)
            pred = logits.argmax(1)
            correct_i += (pred == y).sum().item()
            total_i += y.size(0)
    if total_i > 0:
        acc_i = correct_i / total_i * 100

    # Extract the flip details
    module_name = info.get('module') if isinstance(info, dict) else None
    old_val = new_val = elem_idx = bit_idx = None
    if isinstance(info, dict):
        if 'result' in info and info['result'] is not None:
            # 'result' is expected to unpack into four values; anything else leaves the fields as None
            try:
                old_val, new_val, elem_idx, bit_idx = info['result']
            except Exception:
                pass
        else:
            elem_idx = info.get('elem_idx')
            bit_idx = info.get('bit_idx')

    sig = ret.get('sig', {})
    fp = ret.get('fp', {})

    # kl / grad_norm are logged as empty strings on steps where the SIG-Lite probe did not run
    iter_logs.append([
        i + 1,
        ATTACK_MODE,
        module_name,
        old_val,
        new_val,
        elem_idx,
        bit_idx,
        f"{acc_i:.4f}",
        int(sig.get('alert', 0)),
        float(sig.get('kl', 0.0)) if sig.get('ran', 0) else '',
        float(sig.get('grad_norm_l1', 0.0)) if sig.get('ran', 0) else '',
        int(fp.get('alert', 0)),
        float(fp.get('max_psi', 0.0)),
        float(fp.get('max_entropy_drift', 0.0)),
    ])

# Write the per-iteration log; the header order matches the rows appended above
csv_path = f"results/defense_results/IoTID20_OBFUS_SIG_{ATTACK_MODE}_iterlog.csv"
with open(csv_path, mode='w', newline='') as f:
    writer = csv.writer(f)
    writer.writerow([
        'iteration','mode','module','old_val','new_val','elem_idx','bit_idx',
        'accuracy_after_iter','sig_alert','kl','grad_norm','fp_alert','max_psi','max_entropy_drift'
    ])
    writer.writerows(iter_logs)

print('Saved CSV:', csv_path)



In [None]:
# Override build_model so each model name maps to the correct class
class_names = ['mlp','custom1','custom2']


def build_model(model_name, in_features, n_classes):
    """Instantiate the model registered under ``model_name``.

    Raises:
        ValueError: If ``model_name`` is not one of the known names.
    """
    # Lambdas defer the class lookup until the matching name is actually requested.
    builders = (
        ('mlp', lambda: MLPClassifier(in_features, n_classes)),
        # 'custom1' maps to CustomModel (defined earlier in the notebook)
        ('custom1', lambda: CustomModel(input_size=in_features, output_size=n_classes)),
        ('custom2', lambda: CustomModel2(input_size=in_features, output_size=n_classes)),
    )
    for known_name, make in builders:
        if model_name == known_name:
            return make()
    raise ValueError('Unknown model: ' + model_name)


print('build_model() has been overridden. Available:', class_names)


In [None]:
# Override train/evaluate to adapt input shape for Conv1d models
import torch.nn.functional as F

from typing import Optional

def _first_conv1d(module) -> Optional[torch.nn.Conv1d]:
    """Return the first ``Conv1d`` found in ``module.modules()`` order, or None."""
    conv_layers = (layer for layer in module.modules() if isinstance(layer, torch.nn.Conv1d))
    return next(conv_layers, None)

def _format_input_for_model(model, x):
    """Add the axis a Conv1d-based model needs to a 2-D input.

    Models without a Conv1d, and inputs that are not 2-D, are returned
    unchanged. A 2-D input gets a trailing length-1 axis when the first conv's
    ``in_channels`` equals the feature count, or a channel axis at position 1
    when ``in_channels`` is 1; in any other case it is returned as is.
    """
    first_conv = _first_conv1d(model)
    if first_conv is None or x.dim() != 2:
        return x  # MLP / tabular model, or input already shaped
    channels = first_conv.in_channels
    if channels == x.size(1):
        # Preferred layout: features as channels, length 1
        return x.unsqueeze(-1)
    if channels == 1:
        # Single channel: features become the sequence length
        return x.unsqueeze(1)
    return x

@torch.no_grad()
def evaluate(model, loader, device='cpu'):
    """Return the accuracy of ``model`` over ``loader`` in percent.

    Inputs are reshaped for Conv1d models via ``_format_input_for_model``.
    An empty loader yields ``0.0`` instead of a division by zero.
    """
    model.eval()
    total, correct = 0, 0
    for x, y in loader:
        x, y = x.to(device), y.to(device)
        x = _format_input_for_model(model, x)
        logits = model(x)
        pred = logits.argmax(1)
        correct += (pred == y).sum().item()
        total += y.size(0)
    if total == 0:
        return 0.0
    return correct / total * 100

def train(model, train_loader, test_loader, epochs, lr, device):
    """Train ``model`` with Adam and cross-entropy, printing accuracy and loss after each epoch.

    Inputs are reshaped for Conv1d models before the forward pass. Returns the
    same (now trained) model object.
    """
    criterion = torch.nn.CrossEntropyLoss()
    optimizer = torch.optim.Adam(model.parameters(), lr=lr)
    model.to(device)

    for epoch in range(epochs):
        model.train()
        # Running sample count, correct predictions and sample-weighted loss for this epoch
        total = 0
        correct = 0
        loss_sum = 0.0
        for inputs, targets in train_loader:
            inputs, targets = inputs.to(device), targets.to(device)
            inputs = _format_input_for_model(model, inputs)

            optimizer.zero_grad()
            logits = model(inputs)
            loss = criterion(logits, targets)
            loss.backward()
            optimizer.step()

            batch_size = targets.size(0)
            loss_sum += loss.item() * batch_size
            correct += (logits.argmax(1) == targets).sum().item()
            total += batch_size

        train_acc = correct / total * 100
        test_acc = evaluate(model, test_loader, device)
        print(f'[Epoch {epoch+1:03d}] Train Acc: {train_acc:6.2f}% | Test Acc: {test_acc:6.2f}% | Loss: {loss_sum/total:.4f}')

    print('\nHoàn tất huấn luyện!')
    return model

print('evaluate()/train() overridden with Conv1d input-shape adaptation.')


In [1]:
import os
import torch
import torch.nn as nn
import torch.optim as optim
import urllib.request
import glob
import pandas as pd
import numpy as np
from torch.utils.data import Dataset, DataLoader
from sklearn.preprocessing import StandardScaler, LabelEncoder
from sklearn.model_selection import train_test_split
import torch.nn.functional as F
import math
import torch.nn.init as init

In [None]:
class Config:
    """Notebook-wide settings, read as class attributes through the ``args`` instance."""

    # --- REQUIRED PARAMETERS ---
    # !!! CHANGE THIS PATH to the directory that contains your CSV files
    DATA_ROOT = '../dataset'

    # Model selection: one of 'mlp', 'custom1', 'custom2'
    MODEL_NAME = 'mlp'

    # Training hyper-parameters
    EPOCHS = 15
    LEARNING_RATE = 1e-3

    # Use the GPU when one is available
    DEVICE = 'cuda' if torch.cuda.is_available() else 'cpu'

    # Output path for the saved model
    OUTPUT_PATH = 'save/best_model.pth'

    # Data-loading parameters (leave as None when unused)
    SOURCE_CSV = None
    DOWNLOAD_URL = None


args = Config()
print(f"Sử dụng thiết bị: {args.DEVICE}")

In [None]:
class IoTID20CSVDataset(Dataset):
    """Tabular dataset read from a CSV file, yielding ``(features, label)`` tensor pairs.

    Features are cast to float32 and scaled; labels are integer-encoded. When
    ``scaler`` / ``label_encoder`` are omitted, new ones are fitted on this
    file; when supplied, they are only applied. Both end up on the instance so
    another split can reuse them.
    """

    def __init__(self, csv_file, feature_cols=None, label_col='label', scaler=None, label_encoder=None):
        frame = pd.read_csv(csv_file)
        if feature_cols is None:
            # Default: every column except the label
            feature_cols = [column for column in frame.columns if column != label_col]
        self.feature_cols = feature_cols
        self.label_col = label_col

        features = frame[feature_cols].values.astype(np.float32)
        labels = frame[label_col].values

        # Fit a new scaler only when none was supplied
        if scaler is not None:
            features = scaler.transform(features)
        else:
            scaler = StandardScaler()
            features = scaler.fit_transform(features)
        self.scaler = scaler

        # Same rule for the label encoder
        if label_encoder is not None:
            labels = label_encoder.transform(labels)
        else:
            label_encoder = LabelEncoder()
            labels = label_encoder.fit_transform(labels)
        self.label_encoder = label_encoder

        self.X = torch.from_numpy(features)
        self.y = torch.from_numpy(labels.astype(np.int64))

    def __len__(self):
        """Number of rows."""
        return len(self.X)

    def __getitem__(self, idx):
        """Return the ``(features, label)`` pair at ``idx``."""
        return self.X[idx], self.y[idx]


def _standardize_and_split_source_csv(data_root, out_train, out_test, source_csv=None):
    """Clean a raw source CSV and write an 80/20 train/test split.

    The source is ``source_csv`` when that file exists, otherwise the first
    CSV (in sorted order, so the choice is reproducible) in ``data_root`` that
    is not itself ``train.csv`` / ``test.csv``. Identifier and listed flag
    columns are dropped, infinities and NaNs become 0, duplicates are removed,
    and the label column (``Cat``, else ``Label``, else ``label``) is written
    as ``label``.

    Returns:
        ``True`` when the split files were written, ``False`` when no source
        CSV was found.

    Raises:
        KeyError: If the source has none of the known label columns.
    """
    if source_csv and os.path.isfile(source_csv):
        candidates = [source_csv]
    else:
        candidates = sorted(
            p for p in glob.glob(os.path.join(data_root, '*.csv'))
            if os.path.basename(p).lower() not in {'train.csv', 'test.csv'}
        )
    if not candidates:
        return False

    print(f"Đang xử lý file CSV nguồn: {candidates[0]}")
    df = pd.read_csv(candidates[0], skipinitialspace=True)
    df = df.drop_duplicates()
    unused_cols = ['Flow_ID', 'Src_IP', 'Dst_IP', 'Timestamp', 'Fwd_PSH_Flags', 'Fwd_URG_Flags', 'Fwd_Byts/b_Avg', 'Fwd_Pkts/b_Avg', 'Fwd_Blk_Rate_Avg', 'Bwd_Byts/b_Avg', 'Bwd_Pkts/b_Avg', 'Bwd_Blk_Rate_Avg', 'Init_Fwd_Win_Byts', 'Fwd_Seg_Size_Min']
    df = df.drop(columns=[col for col in unused_cols if col in df.columns])
    df = df.replace([np.inf, -np.inf], np.nan).fillna(0)
    # Dropping columns can create new duplicate rows, so deduplicate again.
    df = df.drop_duplicates()

    label_col_guess = next((c for c in ('Cat', 'Label', 'label') if c in df.columns), None)
    if label_col_guess is None:
        raise KeyError("No label column found; expected one of 'Cat', 'Label', 'label'.")
    labels = df[[label_col_guess]].copy()
    # Remove every label-like column from the features, including the chosen one;
    # otherwise a lowercase 'label' source column would appear twice after the concat.
    label_like = {'Label', 'Cat', 'Sub_Cat', label_col_guess}
    features = df.drop(columns=[c for c in df.columns if c in label_like])

    train_df, test_df = train_test_split(pd.concat([features, labels], axis=1), test_size=0.2, random_state=100)
    train_df = train_df.rename(columns={label_col_guess: 'label'})
    test_df = test_df.rename(columns={label_col_guess: 'label'})

    train_df.to_csv(out_train, index=False)
    test_df.to_csv(out_test, index=False)
    print(f"Đã tạo file train.csv và test.csv tại {data_root}")
    return True

def _download_csv(download_url: str, dest_path: str) -> bool:
    """Download ``download_url`` to ``dest_path``; return True on success.

    Best-effort by design: any failure is printed and reported as ``False``
    rather than raised, so the caller can fall back to other data sources.
    The parent directory is created only when ``dest_path`` has one --
    ``os.makedirs('')`` raises, which previously made every download into the
    current directory fail.
    """
    try:
        dest_dir = os.path.dirname(dest_path)
        if dest_dir:
            os.makedirs(dest_dir, exist_ok=True)
        print(f"Đang tải dữ liệu từ {download_url}...")
        urllib.request.urlretrieve(download_url, dest_path)
        print("Tải xong!")
        return True
    except Exception as e:
        print(f"Tải thất bại: {e}")
        return False

def build_iotid20_loaders(data_root, batch_size=256, num_workers=0, source_csv=None, download_url: str = None):
    """Build train/test DataLoaders from ``train.csv``/``test.csv`` in ``data_root``.

    When either split is missing it is generated from a raw source CSV; if no
    source CSV is available and ``download_url`` is given, the raw file is
    downloaded into ``data_root`` first and the split is retried.

    Returns:
        ``(train_loader, test_loader, n_features, n_classes)``.

    Raises:
        FileNotFoundError: If the splits are missing and cannot be generated.
    """
    train_csv, test_csv = (os.path.join(data_root, name) for name in ('train.csv', 'test.csv'))

    splits_ready = os.path.exists(train_csv) and os.path.exists(test_csv)
    if not splits_ready:
        prepared = _standardize_and_split_source_csv(data_root, train_csv, test_csv, source_csv=source_csv)
        if not prepared and download_url:
            raw_csv = os.path.join(data_root, 'IoT_Network_Intrusion_Dataset.csv')
            if _download_csv(download_url, raw_csv):
                prepared = _standardize_and_split_source_csv(data_root, train_csv, test_csv, source_csv=raw_csv)
        if not prepared:
            raise FileNotFoundError(f'Không tìm thấy file train.csv/test.csv hoặc file CSV nguồn tại {data_root}.')

    # Peek at the header only to decide which column holds the class label.
    header = pd.read_csv(train_csv, nrows=1).columns
    label_col = 'label' if 'label' in header else 'Cat'

    train_ds = IoTID20CSVDataset(train_csv, label_col=label_col)
    # The test split reuses the feature list, scaler and label encoder of the
    # training dataset.
    test_ds = IoTID20CSVDataset(
        test_csv,
        feature_cols=train_ds.feature_cols,
        label_col=label_col,
        scaler=train_ds.scaler,
        label_encoder=train_ds.label_encoder,
    )

    train_loader = DataLoader(train_ds, batch_size=batch_size, shuffle=True, num_workers=num_workers)
    test_loader = DataLoader(test_ds, batch_size=batch_size, shuffle=False, num_workers=num_workers)

    n_features = train_ds.X.shape[1]
    n_classes = len(train_ds.label_encoder.classes_)
    print(f"Tạo Dataloaders thành công: {n_features} features, {n_classes} classes.")
    return train_loader, test_loader, n_features, n_classes

# Notebook cell marker: reports that the data-processing helpers above were defined.
print("Các hàm và lớp xử lý dữ liệu đã sẵn sàng.")

In [None]:
class quan_Linear(nn.Linear):
    """``nn.Linear`` whose weight is quantized to a signed 8-bit grid.

    Attributes:
        N_bits / full_lvls / half_lvls: bit-width, ``2**N_bits`` and
            ``(full_lvls - 2) / 2`` (127 for 8 bits).
        step_size: quantization step, recomputed from the largest absolute
            weight on every forward pass while ``inf_with_weight`` is False.
        inf_with_weight: when True, ``self.weight`` is taken to already hold
            integer levels and is only rescaled by ``step_size``.
        b_w: column vector of two's-complement bit weights
            (``-2**(N_bits-1), ..., 2, 1``); not trainable.
    """

    def __init__(self, in_features, out_features, bias=True):
        super().__init__(in_features, out_features, bias=bias)

        self.N_bits = 8
        self.full_lvls = 2 ** self.N_bits
        self.half_lvls = (self.full_lvls - 2) / 2

        # Placeholder value; replaced right away by the weight-derived step.
        self.step_size = nn.Parameter(torch.Tensor([1]), requires_grad=True)
        self.__reset_stepsize__()

        # Quantize on the fly until __reset_weight__() is called.
        self.inf_with_weight = False

        # Per-bit weights, MSB first; the MSB is negative (two's complement).
        bit_exponents = torch.arange(start=self.N_bits - 1, end=-1, step=-1)
        bit_weights = 2 ** bit_exponents.unsqueeze(-1).float()
        bit_weights[0] = -bit_weights[0]
        self.b_w = nn.Parameter(bit_weights, requires_grad=False)

    def forward(self, input):
        """Apply the linear map using quantized (or pre-quantized) weights."""
        if not self.inf_with_weight:
            self.__reset_stepsize__()
            levels = quantize(self.weight, self.step_size, self.half_lvls)
            return F.linear(input, levels * self.step_size, self.bias)
        return F.linear(input, self.weight * self.step_size, self.bias)

    def __reset_stepsize__(self):
        """Set ``step_size`` to ``max|weight| / half_lvls``."""
        with torch.no_grad():
            largest = self.weight.abs().max()
            self.step_size.data = largest / self.half_lvls

    def __reset_weight__(self):
        """Replace the floating-point weight by its integer quantization levels.

        After this call ``forward`` no longer quantizes; it multiplies the
        stored levels by ``step_size``.
        """
        with torch.no_grad():
            integer_levels = quantize(self.weight, self.step_size, self.half_lvls)
            self.weight.data = integer_levels
        self.inf_with_weight = True


In [None]:
class quan_Conv1d(nn.Conv1d):
    """``nn.Conv1d`` with optional on-the-fly weight quantization.

    Note that ``padding`` defaults to 1 here (``nn.Conv1d`` defaults to 0).
    With ``inf_with_weight`` False (the default) this layer convolves with the
    raw floating-point weight; with it True the weight is first passed through
    ``quantize_weight``.

    NOTE(review): ``step_size`` is reset to 1.0 rather than derived from the
    weights, so enabling ``inf_with_weight`` rounds weights to whole numbers;
    confirm this is intended before relying on the quantized path.
    """

    def __init__(self,
                 in_channels,
                 out_channels,
                 kernel_size,
                 stride=1,
                 padding=1,
                 dilation=1,
                 groups=1,
                 bias=True):
        super().__init__(in_channels,
                         out_channels,
                         kernel_size,
                         stride=stride,
                         padding=padding,
                         dilation=dilation,
                         groups=groups,
                         bias=bias)

        # Number of bits used to quantize the weights.
        self.N_bits = 8
        self.full_lvls = 2 ** self.N_bits
        self.half_lvls = (self.full_lvls - 2) / 2

        # Quantization step, registered as a learnable parameter.
        self.step_size = nn.Parameter(torch.Tensor([1]), requires_grad=True)
        self.__reset_stepsize__()

        # Switch for the quantized-weight path; off by default.
        self.inf_with_weight = False

        # Per-bit weights, MSB first; the MSB is negated for two's complement.
        bit_exponents = torch.arange(start=self.N_bits - 1, end=-1, step=-1)
        bit_weights = 2 ** bit_exponents.unsqueeze(-1).float()
        bit_weights[0] = -bit_weights[0]
        self.b_w = nn.Parameter(bit_weights, requires_grad=False)

    def __reset_stepsize__(self):
        """Reset ``step_size`` to 1.0."""
        self.step_size.data.fill_(1.0)

    def forward(self, x):
        """Convolve ``x`` with the quantized or the raw weight, per ``inf_with_weight``."""
        kernel = self.quantize_weight(self.weight) if self.inf_with_weight else self.weight
        return F.conv1d(x, kernel, self.bias, self.stride,
                        self.padding, self.dilation, self.groups)

    def quantize_weight(self, weight):
        """Round ``weight`` to multiples of ``step_size`` and clamp the result.

        The clamp range is ``[-half_lvls * step, (half_lvls - 1) * step]``.
        NOTE(review): that range is asymmetric (-127..126 steps for 8 bits);
        verify it is the intended grid.
        """
        step = self.step_size
        snapped = torch.round(weight / step) * step
        lower = -self.half_lvls * step
        upper = (self.half_lvls - 1) * step
        return torch.clamp(snapped, lower, upper)

In [None]:
def change_quan_bitwidth(model, n_bit):
    """Set the quantization bit-width of every quantized conv/linear layer.

    For each ``quan_Conv1d`` / ``quan_Linear`` in ``model`` this updates
    ``N_bits``, the derived ``full_lvls`` / ``half_lvls``, and rebuilds the
    two's-complement bit-weight vector ``b_w`` from scratch.

    Rebuilding (instead of slicing the existing ``b_w`` and negating its first
    entry) keeps the call idempotent: calling it twice, or with the current
    width, no longer flips the MSB sign back to positive, and the width can
    also be increased.

    NOTE(review): other quantized modules in the file that are not instances
    of these two classes are left untouched; confirm that is intended.

    Args:
        model: ``nn.Module`` whose sub-modules are scanned.
        n_bit: New bit-width, at least 2 (``half_lvls`` is 0 for 1 bit).

    Raises:
        ValueError: If ``n_bit`` is smaller than 2.
    """
    n_bit = int(n_bit)
    if n_bit < 2:
        raise ValueError(f"n_bit must be an integer >= 2, got {n_bit}")

    for m in model.modules():
        if isinstance(m, (quan_Conv1d, quan_Linear)):
            m.N_bits = n_bit
            m.full_lvls = 2 ** n_bit
            m.half_lvls = (m.full_lvls - 2) / 2

            bit_weights = 2 ** torch.arange(start=n_bit - 1, end=-1, step=-1).unsqueeze(-1).float()
            bit_weights[0] = -bit_weights[0]  # MSB is negative in two's complement
            m.b_w.data = bit_weights.to(device=m.b_w.device, dtype=m.b_w.dtype)
            print(m.b_w)
    return

In [None]:

class _quantize_func(torch.autograd.Function):
    """Clip-and-round quantizer with a scaled pass-through gradient.

    Forward clips ``input`` to ``[-half_lvls * step, +half_lvls * step]`` and
    returns ``round(clipped / step)``, i.e. integer-valued levels.
    Backward passes the incoming gradient through divided by ``step_size``;
    ``step_size`` and ``half_lvls`` receive no gradient.
    """

    @staticmethod
    def forward(ctx, input, step_size, half_lvls):
        # Stash what backward needs on the context object.
        ctx.step_size = step_size
        ctx.half_lvls = half_lvls

        lower = -ctx.half_lvls * ctx.step_size.item()
        upper = ctx.half_lvls * ctx.step_size.item()
        clipped = F.hardtanh(input, min_val=lower, max_val=upper)
        return torch.round(clipped / ctx.step_size)

    @staticmethod
    def backward(ctx, grad_output):
        scaled_grad = grad_output.clone() / ctx.step_size
        return scaled_grad, None, None


quantize = _quantize_func.apply

In [None]:
class _bin_func(torch.autograd.Function):
    """Binarize a tensor to {-1, +1} with a scaled pass-through gradient.

    Forward maps entries ``>= 0`` to 1 and entries ``< 0`` to -1; entries that
    satisfy neither comparison (NaN) stay 0. Backward returns the incoming
    gradient divided by ``mu``; ``mu`` itself gets no gradient.
    """

    @staticmethod
    def forward(ctx, input, mu):
        ctx.mu = mu
        signs = torch.zeros_like(input)
        signs.masked_fill_(input.ge(0), 1)
        signs.masked_fill_(input.lt(0), -1)
        return signs

    @staticmethod
    def backward(ctx, grad_output):
        scaled_grad = grad_output.clone() / ctx.mu
        return scaled_grad, None


w_bin = _bin_func.apply

In [None]:
class CustomBlock(nn.Module):
    """Fully connected layer with 16-bit weight quantization.

    The input is flattened to ``(batch, -1)`` and multiplied by the quantized
    weight; an optional softmax is applied to the result.

    Attributes:
        N_bits / full_lvls / half_lvls: bit-width (16), ``2**N_bits`` and
            ``(full_lvls - 2) / 2``.
        step_size: quantization step, ``max|weight| / half_lvls``; refreshed
            on every forward pass while ``inf_with_weight`` is False.
        inf_with_weight: when True, ``weight`` is taken to already hold
            integer levels and is only rescaled by ``step_size``.
        b_w: two's-complement bit weights (MSB negative); not trainable.

    Args:
        in_features: Size of the flattened input.
        out_features: Number of outputs.
        bias: Whether to add a learnable bias.
        apply_softmax: Whether ``forward`` returns softmax probabilities
            instead of raw scores.
    """

    def __init__(self, in_features, out_features, bias=True, apply_softmax=False):
        super(CustomBlock, self).__init__()
        self.N_bits = 16
        self.full_lvls = 2 ** self.N_bits
        self.half_lvls = (self.full_lvls - 2) / 2
        self.apply_softmax = apply_softmax

        # Placeholder value; replaced below by the weight-derived step.
        self.step_size = nn.Parameter(torch.Tensor([1]), requires_grad=True)

        # Storage only; values are set by reset_parameters().
        self.weight = nn.Parameter(torch.empty(out_features, in_features))
        self.bias = nn.Parameter(torch.empty(out_features)) if bias else None

        # Initialize the weights first: the step size is derived from them and
        # must not be computed from uninitialized memory.
        self.reset_parameters()
        self.__reset_stepsize__()

        # Quantize on the fly until __reset_weight__() is called.
        self.inf_with_weight = False

        bit_weights = 2 ** torch.arange(start=self.N_bits - 1,
                                        end=-1,
                                        step=-1).unsqueeze(-1).float()
        bit_weights[0] = -bit_weights[0]  # MSB is negative in two's complement
        self.b_w = nn.Parameter(bit_weights, requires_grad=False)

    def reset_parameters(self):
        """Kaiming-uniform init for the weight, zeros for the bias."""
        nn.init.kaiming_uniform_(self.weight, a=5 ** 0.5)
        if self.bias is not None:
            nn.init.zeros_(self.bias)

    def forward(self, input):
        """Return ``flatten(input) @ W_q.T (+ bias)``, optionally softmaxed."""
        if self.inf_with_weight:
            weight_applied = self.weight * self.step_size
        else:
            self.__reset_stepsize__()
            weight_applied = quantize(self.weight, self.step_size, self.half_lvls) * self.step_size

        # reshape (not view) so non-contiguous inputs are accepted as well.
        flat = input.reshape(input.size(0), -1)
        output = flat @ weight_applied.T
        if self.bias is not None:
            output = output + self.bias

        if self.apply_softmax:
            output = F.softmax(output, dim=-1)
        return output

    def __reset_stepsize__(self):
        """Set ``step_size`` to ``max|weight| / half_lvls``."""
        with torch.no_grad():
            self.step_size.data = self.weight.abs().max() / self.half_lvls

    def __reset_weight__(self):
        """Store integer quantization levels in ``weight`` and stop re-quantizing."""
        with torch.no_grad():
            self.weight.data = quantize(self.weight, self.step_size, self.half_lvls)
        self.inf_with_weight = True


In [None]:
class DownsampleA(nn.Module):
    """Parameter-free shortcut: subsample the length and double the channels.

    The input is subsampled along its last dimension with a kernel-1 average
    pool of stride 2, then concatenated along dim 1 with an all-zero copy of
    itself. ``nIn`` and ``nOut`` are accepted but not used.
    """

    def __init__(self, nIn, nOut, stride):
        super().__init__()
        assert stride == 2
        self.avg = nn.AvgPool1d(kernel_size=1, stride=stride)

    def forward(self, x):
        pooled = self.avg(x)
        # Multiplying by zero (rather than allocating zeros) keeps dtype,
        # device and shape tied to the pooled tensor.
        zero_half = pooled.mul(0)
        return torch.cat((pooled, zero_half), 1)
        
class SEBlock(nn.Module):
    """Residual block with two 3-tap quantized convolutions.

    Main path: conv -> BatchNorm -> ReLU -> dropout -> conv -> BatchNorm ->
    dropout. The shortcut is the input itself, or ``downsample(input)`` when a
    downsample module is given. The block output is ``ReLU(shortcut + main)``.

    NOTE(review): despite the class name, no squeeze-and-excitation branch is
    implemented here.

    Args:
        inplanes: Input channels.
        planes: Output channels of both convolutions.
        stride: Stride of the first convolution. It was previously accepted
            but ignored; it is now applied so the main path shrinks the length
            the same way a strided shortcut does.
        downsample: Optional module applied to the input to form the shortcut.
    """
    expansion = 1

    def __init__(self, inplanes, planes, stride=1, downsample=None):
        super(SEBlock, self).__init__()
        self.conv_a = quan_Conv1d(inplanes, planes, kernel_size=3, stride=stride, padding=1, bias=False)
        self.bn_a = nn.BatchNorm1d(planes)
        self.dropout_a = nn.Dropout(p=0.3)  # dropout after BatchNorm + ReLU

        self.conv_b = quan_Conv1d(planes, planes, kernel_size=3, stride=1, padding=1, bias=False)
        self.bn_b = nn.BatchNorm1d(planes)
        self.dropout_b = nn.Dropout(p=0.3)  # dropout after BatchNorm

        self.downsample = downsample

    def forward(self, x):
        # Explicit None check: module truthiness can be False for containers
        # that define __len__ (e.g. an empty nn.Sequential).
        residual = x if self.downsample is None else self.downsample(x)

        basicblock = self.conv_a(x)
        basicblock = self.bn_a(basicblock)
        basicblock = F.relu(basicblock, inplace=True)
        basicblock = self.dropout_a(basicblock)

        basicblock = self.conv_b(basicblock)
        basicblock = self.bn_b(basicblock)
        basicblock = self.dropout_b(basicblock)

        return F.relu(residual + basicblock, inplace=True)

In [None]:
class CustomModel(nn.Module):
    """Residual 1-D CNN built from quantized convolutions.

    Layout: a quantized conv stem + BatchNorm, three stages of 16 ``SEBlock``s
    each, global average pooling and a ``CustomBlock`` classifier.

    The stage widths are ``w``, ``2w`` and ``4w`` with ``w = hidden_sizes[0]``
    (32, 64, 128 by default). They follow the stem width because the
    ``DownsampleA`` shortcut used between stages always doubles the channel
    count. Entries of ``hidden_sizes`` after the first are not used.

    Args:
        input_size: Channels of the input, which must be ``[B, input_size, L]``.
        hidden_sizes: Sequence whose first entry is the stem/stage-1 width.
        output_size: Number of classes.

    NOTE(review): the stride-2 shortcut subsamples the length, so the residual
    add only lines up when the block's conv path shrinks the length the same
    way; for length-1 inputs this holds trivially. Confirm the input length
    the data pipeline produces.
    """

    def __init__(self, input_size=69, hidden_sizes=(32, 64, 128, 256, 512), output_size=5):
        super(CustomModel, self).__init__()
        width = hidden_sizes[0]

        self.fc1 = quan_Conv1d(input_size, width, kernel_size=3, stride=1, padding=1)
        self.bn_1 = nn.BatchNorm1d(width)

        # Running channel count consumed by _make_layer; must match the stem
        # output (it used to be hard-coded to 32).
        self.inplanes = width
        self.stage_1 = self._make_layer(SEBlock, width, 16, 1)
        self.stage_2 = self._make_layer(SEBlock, width * 2, 16, 2)
        self.stage_3 = self._make_layer(SEBlock, width * 4, 16, 2)
        self.avgpool = nn.AdaptiveAvgPool1d(1)

        self.classifier = CustomBlock(width * 4 * SEBlock.expansion, output_size)

        for m in self.modules():
            if isinstance(m, nn.Conv1d):
                n = m.kernel_size[0] * m.out_channels
                m.weight.data.normal_(0, math.sqrt(2. / n))
            elif isinstance(m, nn.BatchNorm1d):
                m.weight.data.fill_(1)
                m.bias.data.zero_()
            elif isinstance(m, nn.Linear):
                # nn.init.kaiming_normal_ replaces the deprecated
                # init.kaiming_normal, whose `init` alias this block never defined.
                nn.init.kaiming_normal_(m.weight)
                if m.bias is not None:
                    m.bias.data.zero_()

    def _make_layer(self, block, planes, blocks, stride=1):
        """Stack ``blocks`` residual blocks producing ``planes`` channels.

        Only the first block gets ``stride`` and, for stride 2, a
        ``DownsampleA`` shortcut; the remaining blocks keep the shape.
        """
        downsample = None
        if stride == 2:
            downsample = DownsampleA(self.inplanes, planes * SEBlock.expansion, stride)

        layers = [block(self.inplanes, planes, stride=stride, downsample=downsample)]
        self.inplanes = planes * SEBlock.expansion
        for _ in range(1, blocks):
            layers.append(block(self.inplanes, planes))

        return nn.Sequential(*layers)

    def forward(self, x):
        x = self.fc1(x)
        x = F.relu(self.bn_1(x), inplace=True)

        x = self.stage_1(x)
        x = self.stage_2(x)
        x = self.stage_3(x)

        x = self.avgpool(x)
        x = x.view(x.size(0), -1)
        return self.classifier(x)

In [None]:
class CustomModel2(nn.Module):
    """Four strided ``nn.Conv1d`` layers followed by a ``CustomBlock`` classifier.

    After every convolution the sequence is reduced to length 1 by an adaptive
    max pool and passed through ReLU, so each later convolution sees a
    length-1 input.

    The classifier is created with ``apply_softmax=True``, so ``forward``
    returns class probabilities rather than logits.
    NOTE(review): feeding these outputs to ``nn.CrossEntropyLoss`` applies a
    softmax twice; confirm this is intended.

    Args:
        input_size: Channels of the input, which must be ``[B, input_size, L]``.
        hidden_sizes: Output channels of the four convolutions.
        output_size: Number of classes.
    """

    def __init__(self, input_size=69, hidden_sizes=[32, 64, 128, 100], output_size=5):
        super(CustomModel2, self).__init__()
        # Copy so the instance never aliases the shared default list.
        hidden_sizes = list(hidden_sizes)
        self.hidden_sizes = hidden_sizes

        self.fc1 = nn.Conv1d(input_size, hidden_sizes[0], kernel_size=3, stride=2, padding=1)
        self.pool = nn.AdaptiveMaxPool1d(1)
        self.activation = nn.ReLU()
        # Registered but not used in forward(); kept so the module tree
        # (and any external reference to `.dropout`) stays the same.
        self.dropout = nn.Dropout(0.1)

        self.stage_1 = nn.Conv1d(hidden_sizes[0], hidden_sizes[1], kernel_size=3, stride=2, padding=1)
        self.stage_2 = nn.Conv1d(hidden_sizes[1], hidden_sizes[2], kernel_size=3, stride=2, padding=1)
        self.stage_3 = nn.Conv1d(hidden_sizes[2], hidden_sizes[3], kernel_size=3, stride=2, padding=1)

        # Global pooling (a no-op on the length-1 output of the last stage).
        self.global_pool = nn.AdaptiveAvgPool1d(1)

        # Classifier. The stray `nn.Dropout(0.15)` statement that followed it
        # built a module and discarded it, so it has been removed.
        self.classifier = CustomBlock(hidden_sizes[-1], output_size, apply_softmax=True)

    def forward(self, x):
        x = self.fc1(x)
        x = self.activation(self.pool(x))

        x = self.stage_1(x)
        x = self.activation(self.pool(x))

        x = self.stage_2(x)
        x = self.activation(self.pool(x))

        x = self.stage_3(x)
        x = self.activation(self.pool(x))

        x = self.global_pool(x)
        x = x.view(x.size(0), -1)  # flatten to [B, hidden_sizes[-1]]
        return self.classifier(x)

In [None]:
class MLPClassifier(nn.Module):
    """Three-layer perceptron: ``in_features -> 512 -> 256 -> n_classes``.

    Each hidden layer is followed by ReLU and dropout (p=0.2); the output
    layer returns raw scores.
    """

    def __init__(self, in_features, n_classes):
        super().__init__()
        widths = (in_features, 512, 256)
        stack = []
        for fan_in, fan_out in zip(widths[:-1], widths[1:]):
            stack += [nn.Linear(fan_in, fan_out), nn.ReLU(), nn.Dropout(0.2)]
        stack.append(nn.Linear(widths[-1], n_classes))
        self.layers = nn.Sequential(*stack)

    def forward(self, x):
        return self.layers(x)
# Helper that selects a model architecture from the configuration name.
def build_model(model_name, in_features, n_classes):
    """Instantiate the model registered under ``model_name``.

    Supported names: ``'mlp'`` (``MLPClassifier``), ``'custom1'``
    (``CustomModel``) and ``'custom2'`` (``CustomModel2``).

    Args:
        model_name: One of the names listed above.
        in_features: Number of input features/channels.
        n_classes: Number of output classes.

    Raises:
        ValueError: If ``model_name`` is not a supported name.
    """
    if model_name == 'mlp':
        return MLPClassifier(in_features, n_classes)
    if model_name == 'custom1':
        # The class defined for this option is `CustomModel`; the previously
        # referenced `CustomModel1` is not defined alongside it.
        return CustomModel(input_size=in_features, output_size=n_classes)
    if model_name == 'custom2':
        return CustomModel2(input_size=in_features, output_size=n_classes)
    raise ValueError(f'Unknown model: {model_name}')

# Notebook cell marker: reports that the model architectures above were defined.
print("Các kiến trúc model đã sẵn sàng.")

In [None]:
@torch.no_grad()
def evaluate(model, loader, device='cpu'):
    """Return the top-1 accuracy of ``model`` over ``loader``, in percent.

    The model is switched to eval mode and left in it. Batches are moved to
    ``device``; the model itself is not moved.

    Args:
        model: Module mapping a batch ``x`` to per-class scores ``[B, C]``.
        loader: Iterable of ``(x, y)`` tensor batches.
        device: Device the batches are moved to.

    Returns:
        Accuracy in ``[0, 100]``; ``0.0`` when ``loader`` yields no samples.
    """
    model.eval()
    total, correct = 0, 0
    for x, y in loader:
        x, y = x.to(device), y.to(device)
        pred = model(x).argmax(1)
        correct += (pred == y).sum().item()
        total += y.size(0)
    if total == 0:
        # An empty loader would otherwise raise ZeroDivisionError.
        return 0.0
    return correct / total * 100

def train(model, train_loader, test_loader, epochs, lr, device):
    """Train ``model`` with Adam and cross-entropy, reporting accuracy per epoch.

    After every epoch the training accuracy, the test accuracy (via
    ``evaluate``) and the mean training loss are printed.

    Args:
        model: Module mapping a batch ``x`` to per-class scores ``[B, C]``.
        train_loader: Iterable of ``(x, y)`` training batches.
        test_loader: Iterable of ``(x, y)`` batches passed to ``evaluate``.
        epochs: Number of passes over ``train_loader``.
        lr: Adam learning rate.
        device: Device the model and the batches are moved to.

    Returns:
        The same ``model`` object, trained in place.
    """
    criterion = nn.CrossEntropyLoss()
    # Move the model before building the optimizer, as the torch.optim docs
    # advise, so the optimizer is constructed over the on-device parameters.
    model.to(device)
    optimizer = optim.Adam(model.parameters(), lr=lr)

    for epoch in range(epochs):
        model.train()
        total, correct, loss_sum = 0, 0, 0.0
        for x, y in train_loader:
            x, y = x.to(device), y.to(device)
            optimizer.zero_grad()
            logits = model(x)
            loss = criterion(logits, y)
            loss.backward()
            optimizer.step()
            # Weight by batch size so the epoch loss is a per-sample mean.
            loss_sum += loss.item() * y.size(0)
            correct += (logits.argmax(1) == y).sum().item()
            total += y.size(0)

        # Guard against an empty train loader (would divide by zero).
        seen = max(total, 1)
        train_acc = correct / seen * 100
        test_acc = evaluate(model, test_loader, device)
        print(f'[Epoch {epoch+1:03d}] Train Acc: {train_acc:6.2f}% | Test Acc: {test_acc:6.2f}% | Loss: {loss_sum/seen:.4f}')

    print("\nHoàn tất huấn luyện!")
    return model

# Notebook cell marker: reports that train()/evaluate() above were defined.
print("Các hàm train/evaluate đã sẵn sàng.")

In [None]:
class PureCNN(nn.Module):
    """Pure CNN model without quantization, for comparison.

    A feature vector is treated as a single-channel sequence ``[B, 1, F]`` and
    passed through four conv-BN-ReLU-maxpool blocks, global average pooling
    and three fully connected layers.

    Accepted inputs: ``[B, F]``, ``[B, F, 1]`` or ``[B, 1, F]``. The four
    stride-2 max pools need ``F >= 16``.

    The first convolution takes 1 input channel. It used to be built with
    ``input_size`` channels, which contradicted the ``[B, 1, F]`` layout that
    ``forward`` produces and made every documented input fail.

    Args:
        input_size: Number of features ``F`` per sample. Kept for interface
            compatibility; the layers do not depend on it.
        output_size: Number of classes.
    """

    def __init__(self, input_size=69, output_size=5):
        super().__init__()
        self.input_size = input_size

        # Convolutional layers with increasing channels.
        self.conv1 = nn.Conv1d(1, 64, kernel_size=3, stride=1, padding=1)
        self.bn1 = nn.BatchNorm1d(64)
        self.pool1 = nn.MaxPool1d(kernel_size=2, stride=2)

        self.conv2 = nn.Conv1d(64, 128, kernel_size=3, stride=1, padding=1)
        self.bn2 = nn.BatchNorm1d(128)
        self.pool2 = nn.MaxPool1d(kernel_size=2, stride=2)

        self.conv3 = nn.Conv1d(128, 256, kernel_size=3, stride=1, padding=1)
        self.bn3 = nn.BatchNorm1d(256)
        self.pool3 = nn.MaxPool1d(kernel_size=2, stride=2)

        self.conv4 = nn.Conv1d(256, 512, kernel_size=3, stride=1, padding=1)
        self.bn4 = nn.BatchNorm1d(512)
        self.pool4 = nn.MaxPool1d(kernel_size=2, stride=2)

        # Adaptive pooling collapses whatever length is left to 1.
        self.adaptive_pool = nn.AdaptiveAvgPool1d(1)

        # Fully connected head.
        self.dropout1 = nn.Dropout(0.3)
        self.fc1 = nn.Linear(512, 256)
        self.dropout2 = nn.Dropout(0.3)
        self.fc2 = nn.Linear(256, 128)
        self.dropout3 = nn.Dropout(0.2)
        self.classifier = nn.Linear(128, output_size)

        self._initialize_weights()

    def _initialize_weights(self):
        """Xavier-uniform weights, zero biases; BatchNorm to weight 1 / bias 0."""
        for module in self.modules():
            if isinstance(module, (nn.Conv1d, nn.Linear)):
                nn.init.xavier_uniform_(module.weight)
                if module.bias is not None:
                    nn.init.zeros_(module.bias)
            elif isinstance(module, nn.BatchNorm1d):
                nn.init.ones_(module.weight)
                nn.init.zeros_(module.bias)

    def forward(self, x):
        # Normalize the input layout to [B, 1, F].
        if x.dim() == 2:
            x = x.unsqueeze(1)  # [B, F] -> [B, 1, F]
        elif x.dim() == 3 and x.size(-1) == 1:
            x = x.transpose(1, 2)  # [B, F, 1] -> [B, 1, F]

        # Convolutional blocks.
        x = self.pool1(F.relu(self.bn1(self.conv1(x))))
        x = self.pool2(F.relu(self.bn2(self.conv2(x))))
        x = self.pool3(F.relu(self.bn3(self.conv3(x))))
        x = self.pool4(F.relu(self.bn4(self.conv4(x))))

        # Global average pooling.
        x = self.adaptive_pool(x)  # [B, 512, 1]
        x = x.view(x.size(0), -1)  # [B, 512]

        # Fully connected layers.
        x = self.dropout1(x)
        x = F.relu(self.fc1(x))

        x = self.dropout2(x)
        x = F.relu(self.fc2(x))

        x = self.dropout3(x)
        return self.classifier(x)


class EfficientCNN(nn.Module):
    """Lightweight CNN with depthwise separable convolutions.

    A feature vector is treated as a single-channel sequence ``[B, 1, F]`` and
    passed through three (depthwise conv, pointwise conv, BN, ReLU, maxpool)
    blocks, global average pooling, dropout and a linear classifier.

    Accepted inputs: ``[B, F]``, ``[B, F, 1]`` or ``[B, 1, F]``. The three
    stride-2 max pools need ``F >= 8``.

    The first block works on 1 input channel. It used to be built with
    ``input_size`` channels, which contradicted the ``[B, 1, F]`` layout that
    ``forward`` produces and made every documented input fail.

    Args:
        input_size: Number of features ``F`` per sample. Kept for interface
            compatibility; the layers do not depend on it.
        output_size: Number of classes.
    """

    def __init__(self, input_size=69, output_size=5):
        super().__init__()
        self.input_size = input_size

        # Depthwise separable blocks: a per-channel 3-tap conv followed by a
        # 1x1 conv that mixes channels.
        self.conv1 = nn.Conv1d(1, 1, kernel_size=3, stride=1, padding=1, groups=1)
        self.pw_conv1 = nn.Conv1d(1, 64, kernel_size=1)
        self.bn1 = nn.BatchNorm1d(64)
        self.pool1 = nn.MaxPool1d(kernel_size=2, stride=2)

        self.conv2 = nn.Conv1d(64, 64, kernel_size=3, stride=1, padding=1, groups=64)
        self.pw_conv2 = nn.Conv1d(64, 128, kernel_size=1)
        self.bn2 = nn.BatchNorm1d(128)
        self.pool2 = nn.MaxPool1d(kernel_size=2, stride=2)

        self.conv3 = nn.Conv1d(128, 128, kernel_size=3, stride=1, padding=1, groups=128)
        self.pw_conv3 = nn.Conv1d(128, 256, kernel_size=1)
        self.bn3 = nn.BatchNorm1d(256)
        self.pool3 = nn.MaxPool1d(kernel_size=2, stride=2)

        # Global pooling and classification.
        self.adaptive_pool = nn.AdaptiveAvgPool1d(1)
        self.dropout = nn.Dropout(0.4)
        self.classifier = nn.Linear(256, output_size)

        self._initialize_weights()

    def _initialize_weights(self):
        """Xavier-uniform weights and zero biases for conv and linear layers."""
        for module in self.modules():
            if isinstance(module, (nn.Conv1d, nn.Linear)):
                nn.init.xavier_uniform_(module.weight)
                if module.bias is not None:
                    nn.init.zeros_(module.bias)

    def forward(self, x):
        # Normalize the input layout to [B, 1, F].
        if x.dim() == 2:
            x = x.unsqueeze(1)  # [B, F] -> [B, 1, F]
        elif x.dim() == 3 and x.size(-1) == 1:
            x = x.transpose(1, 2)  # [B, F, 1] -> [B, 1, F]

        # Depthwise separable blocks.
        x = self.pool1(F.relu(self.bn1(self.pw_conv1(self.conv1(x)))))
        x = self.pool2(F.relu(self.bn2(self.pw_conv2(self.conv2(x)))))
        x = self.pool3(F.relu(self.bn3(self.pw_conv3(self.conv3(x)))))

        # Global pooling.
        x = self.adaptive_pool(x)
        x = x.view(x.size(0), -1)

        x = self.dropout(x)
        return self.classifier(x)


In [None]:

# 1. Create the output directory if it does not exist yet.
#    dirname is '' when OUTPUT_PATH is a bare filename, and os.makedirs('')
#    raises, so the directory is only created when there is one to create.
output_dir = os.path.dirname(args.OUTPUT_PATH)
if output_dir:
    os.makedirs(output_dir, exist_ok=True)

# 2. Prepare the data.
train_loader, test_loader, in_features, n_classes = build_iotid20_loaders(
    args.DATA_ROOT,
    source_csv=args.SOURCE_CSV,
    download_url=args.DOWNLOAD_URL
)

# 3. Build the model.
model = build_model(args.MODEL_NAME, in_features, n_classes)
print(f"Đã xây dựng model '{args.MODEL_NAME}'")
print(model)

# 4. Start training.
trained_model = train(
    model=model,
    train_loader=train_loader,
    test_loader=test_loader,
    epochs=args.EPOCHS,
    lr=args.LEARNING_RATE,
    device=args.DEVICE
)

# 5. Save the model weights.
torch.save(trained_model.state_dict(), args.OUTPUT_PATH)
print(f'\nĐã lưu trọng số vào file: {args.OUTPUT_PATH}')