세 모델 비교: 1. TPA 2. GAP 3. TPA에서 lowpass·depthwise conv·pointwise conv를 모두 제거한 모델 (TPA-Simple)

In [None]:
# Colab-only: mount Google Drive so the dataset / output paths under
# /content/drive used later in this notebook are reachable.
from google.colab import drive
drive.mount('/content/drive')

Mounted at /content/drive


In [None]:
# -*- coding: utf-8 -*-
"""
TPA Noise Robustness Study: 3-Model Comparison
Models: GAP (Baseline), TPA (Full), TPA-Simple (No Conv)
Noise Types: Gaussian, Temporal Mask, Channel Drift, Channel Drop, Spike
"""

import os, random, math, sys, time, copy, json
import numpy as np
import pandas as pd
from typing import Tuple, List
from dataclasses import dataclass

import torch
import torch.nn as nn
import torch.nn.functional as F
from torch.utils.data import Dataset, DataLoader

from sklearn.metrics import accuracy_score, f1_score, confusion_matrix, classification_report
from sklearn.model_selection import train_test_split

# ========================
# 0) Config & Reproducibility
# ========================
SEED = 42


def _seed_everything(seed):
    """Seed the Python, NumPy and PyTorch (CPU + every CUDA device) RNGs."""
    random.seed(seed)
    np.random.seed(seed)
    torch.manual_seed(seed)
    torch.cuda.manual_seed_all(seed)


_seed_everything(SEED)
# Ask cuDNN for deterministic kernels and turn off its autotuner, which may
# otherwise select different algorithms from one run to the next.
torch.backends.cudnn.deterministic = True
torch.backends.cudnn.benchmark = False

@dataclass
class Config:
    """All hyperparameters and paths for the noise-robustness study.

    `device` is a plain default expression, so it is decided once, when this
    class is defined, from `torch.cuda.is_available()`.
    """

    # ---- paths (Google Drive mount) -----------------------------------------
    data_dir: str = "/content/drive/MyDrive/AI_data/UCI_HAR_Dataset/UCI HAR Dataset"
    save_dir: str = "/content/drive/MyDrive/AI_data/noise_robustness_study"

    # ---- optimisation -------------------------------------------------------
    epochs: int = 100
    batch_size: int = 128
    lr: float = 1e-4
    weight_decay: float = 1e-4
    grad_clip: float = 1.0          # max global grad-norm
    label_smoothing: float = 0.05

    # ---- early stopping / validation split ----------------------------------
    patience: int = 20              # epochs without improvement before stopping
    min_delta: float = 0.0001       # minimum val-accuracy gain counted as improvement
    val_split: float = 0.2          # fraction of the train split held out

    # ---- training-time noise augmentation -----------------------------------
    train_augment_prob: float = 0.7
    max_train_noise: float = 0.4

    # ---- backbone width -----------------------------------------------------
    d_model: int = 128

    # ---- Temporal Prototype Attention head ----------------------------------
    tpa_num_prototypes: int = 16
    tpa_seg_kernel: int = 9
    tpa_heads: int = 4
    tpa_dropout: float = 0.1
    tpa_temperature: float = 0.07
    tpa_topk_ratio: float = 0.25

    # ---- runtime ------------------------------------------------------------
    device: str = "cuda" if torch.cuda.is_available() else "cpu"
    num_workers: int = 2

# ========================
# 1) Noise Augmentation Functions
# ========================

def add_gaussian_noise(X, noise_level):
    """Return X plus zero-mean Gaussian noise with standard deviation `noise_level`.

    The noise is drawn from NumPy's global RNG and cast to float32. A level of
    exactly 0 returns X itself (same object, no RNG draw).
    """
    if noise_level != 0:
        X = X + np.random.normal(0, noise_level, X.shape).astype(np.float32)
    return X

def add_temporal_mask(X, mask_ratio=0.3):
    """Zero one contiguous time window in roughly half of the samples.

    Args:
        X: array unpacked as (B, C, T); the window spans all C channels.
        mask_ratio: window length as a fraction of T; the length is
            int(T * mask_ratio), clipped to [0, T].

    Returns:
        A masked copy of X (X itself is not modified).
    """
    X_aug = X.copy()
    B, C, T = X.shape
    # Same window length for every sample; clipping keeps randint valid even
    # for mask_ratio >= 1 (previously randint(0, 0) raised ValueError).
    mask_len = min(max(int(T * mask_ratio), 0), T)
    for i in range(B):
        # Each sample is masked independently with probability 0.5.
        if np.random.rand() < 0.5:
            # randint's upper bound is exclusive: +1 allows the window to end
            # at the final time step (previously that start was never drawn).
            start = np.random.randint(0, T - mask_len + 1)
            X_aug[i, :, start:start + mask_len] = 0
    return X_aug

def add_channel_drift(X, drift_std=0.2):
    """Add a constant random offset to about 30% of the (sample, channel) rows.

    X is unpacked as (B, C, T). For every (sample, channel) pair, in row-major
    order, one uniform draw decides (p = 0.3) whether that row receives an
    offset ~ N(0, drift_std), applied uniformly across the last axis.
    Returns a modified copy.
    """
    n_samples, n_channels, _ = X.shape
    drifted = X.copy()
    for flat in range(n_samples * n_channels):
        sample_idx, channel_idx = divmod(flat, n_channels)
        if np.random.rand() >= 0.3:
            continue
        drifted[sample_idx, channel_idx, :] += np.random.normal(0, drift_std)
    return drifted

def add_channel_drop(X, drop_prob=0.2):
    """Zero whole channels, each independently with probability `drop_prob`.

    X is unpacked as (B, C, T); a fresh mask of C uniform draws is taken per
    sample. Returns a modified copy.
    """
    _, n_channels, _ = X.shape
    dropped = X.copy()
    # Iterating a 3-D array yields writable (C, T) views, one per sample.
    for sample in dropped:
        sample[np.random.rand(n_channels) < drop_prob] = 0
    return dropped

def add_spike_noise(X, spike_prob=0.05, spike_magnitude=2.0):
    """Add sparse impulse noise.

    Each element is hit independently with probability `spike_prob`; a hit
    element gains a N(0, 1) * `spike_magnitude` spike. The result keeps X's
    dtype. Returns a modified copy.
    """
    noisy = X.copy()
    hit = np.random.rand(*X.shape) < spike_prob
    amplitude = spike_magnitude * np.random.randn(*X.shape)
    noisy[hit] += amplitude[hit]
    return noisy

def apply_mixed_augmentation(X, max_noise=0.4):
    """Corrupt X with one randomly chosen noise recipe.

    One of six recipes is drawn uniformly. Five apply a single corruption
    (gaussian / temporal mask / channel drift / channel drop / spike) with a
    randomly drawn strength. 'mixed' chains 2 or 3 distinct corruptions
    (never channel drop) at milder strengths. `max_noise` bounds the Gaussian
    std (halved inside 'mixed'). All randomness uses NumPy's global RNG.
    """
    recipe = np.random.choice([
        'gaussian', 'temporal', 'drift', 'drop', 'spike', 'mixed'
    ])

    if recipe != 'mixed':
        # Each entry draws its strength first, then applies the corruption.
        single = {
            'gaussian': lambda: add_gaussian_noise(X, np.random.uniform(0, max_noise)),
            'temporal': lambda: add_temporal_mask(X, np.random.uniform(0.1, 0.4)),
            'drift': lambda: add_channel_drift(X, np.random.uniform(0.1, 0.3)),
            'drop': lambda: add_channel_drop(X, np.random.uniform(0.1, 0.3)),
            'spike': lambda: add_spike_noise(X, np.random.uniform(0.02, 0.08)),
        }
        return single[recipe]()

    # Milder variants used when several corruptions are stacked.
    mild = {
        'gaussian': lambda arr: add_gaussian_noise(arr, np.random.uniform(0, max_noise*0.5)),
        'temporal': lambda arr: add_temporal_mask(arr, np.random.uniform(0.1, 0.2)),
        'drift': lambda arr: add_channel_drift(arr, np.random.uniform(0.05, 0.15)),
        'spike': lambda arr: add_spike_noise(arr, np.random.uniform(0.02, 0.05)),
    }
    out = X.copy()
    n_steps = np.random.randint(2, 4)  # 2 or 3; the upper bound is exclusive
    for step in np.random.choice(list(mild), size=n_steps, replace=False):
        out = mild[step](out)
    return out

# ========================
# 2) UCI-HAR Data Loader
# ========================
_RAW_CHANNELS = [
    ("Inertial Signals/total_acc_x_", "txt"), ("Inertial Signals/total_acc_y_", "txt"),
    ("Inertial Signals/total_acc_z_", "txt"), ("Inertial Signals/body_acc_x_", "txt"),
    ("Inertial Signals/body_acc_y_", "txt"), ("Inertial Signals/body_acc_z_", "txt"),
    ("Inertial Signals/body_gyro_x_", "txt"), ("Inertial Signals/body_gyro_y_", "txt"),
    ("Inertial Signals/body_gyro_z_", "txt"),
]

def _load_split_raw(root: str, split: str) -> Tuple[np.ndarray, np.ndarray]:
    assert split in ("train", "test")
    X_list = [np.loadtxt(os.path.join(root, split, p + split + "." + e))[..., None] for p, e in _RAW_CHANNELS]
    X = np.concatenate(X_list, axis=-1).transpose(0, 2, 1)
    y = np.loadtxt(os.path.join(root, split, f"y_{split}.txt")).astype(int)
    return X, y

class UCIHARInertial(Dataset):
    def __init__(self, root: str, split: str, mean=None, std=None,
                 preloaded_data: Tuple[np.ndarray, np.ndarray] | None = None,
                 indices: np.ndarray | None = None,
                 augment: bool = False, max_noise: float = 0.4, augment_prob: float = 0.7):
        super().__init__()

        if preloaded_data is not None:
            X, y = preloaded_data
        else:
            X, y = _load_split_raw(root, split)

        if indices is not None:
            X = X[indices]
            y = y[indices]

        self.X = X.astype(np.float32)
        self.y = (y - 1).astype(np.int64) if y.min() >= 1 else y.astype(np.int64)

        if mean is not None and std is not None:
            self.mean, self.std = mean, std
            if preloaded_data is None:
                self.X = (self.X - self.mean) / self.std
        else:
            self.mean = self.X.mean(axis=(0,2), keepdims=True).astype(np.float32)
            self.std = (self.X.std(axis=(0,2), keepdims=True) + 1e-6).astype(np.float32)
            self.X = ((self.X - self.mean) / self.std).astype(np.float32)

        self.augment = augment
        self.max_noise = max_noise
        self.augment_prob = augment_prob

    def __len__(self):
        return self.X.shape[0]

    def __getitem__(self, idx):
        x = self.X[idx].copy()
        y = self.y[idx]

        if self.augment and np.random.rand() < self.augment_prob:
            x = apply_mixed_augmentation(x[np.newaxis], self.max_noise)[0]

        return torch.from_numpy(x).float(), torch.tensor(y, dtype=torch.long)

# ========================
# 3) TPA Module
# ========================
class ProductionTPA(nn.Module):
    """Temporal Prototype Attention (full version with a conv front-end).

    Learnable prototypes act as queries that cross-attend over time. The
    front-end is a fixed-size (5) depthwise lowpass conv, a depthwise conv
    with kernel `seg_kernel`, and a pointwise conv, followed by LayerNorm plus
    a residual to the input.

    The attended prototype tokens are reduced with a per-feature top-k mean,
    passed through an MLP, and blended with the plain temporal mean of the
    input through a learned sigmoid gate.

    Args:
        dim: feature width D; must be divisible by `heads`.
        num_prototypes: number of learnable query prototypes P.
        seg_kernel: depthwise conv kernel size. Any size works; 'same' padding
            keeps the sequence length T for even kernels too.
        heads: number of attention heads.
        dropout: dropout on attention weights and inside the MLPs.
        temperature: divisor of the cosine-similarity scores.
        topk_ratio: fraction of prototypes kept (at least 1) in the top-k mean.
    """

    def __init__(self, dim, num_prototypes=16, seg_kernel=9, heads=4, dropout=0.1,
                 temperature=0.07, topk_ratio=0.25):
        super().__init__()
        assert dim % heads == 0

        self.dim = dim
        self.heads = heads
        self.head_dim = dim // heads
        self.num_prototypes = num_prototypes
        self.temperature = temperature
        self.topk_ratio = topk_ratio

        self.proto = nn.Parameter(torch.randn(num_prototypes, dim) * 0.02)

        self.lowpass = nn.Conv1d(dim, dim, kernel_size=5, padding=2, groups=dim, bias=False)
        # 'same' padding: for odd kernels this equals the old symmetric
        # (k - 1) // 2 padding; for even kernels the old padding shortened the
        # sequence by one step and the residual `+ x` below failed.
        self.dw = nn.Conv1d(dim, dim, kernel_size=seg_kernel, padding='same', groups=dim, bias=False)
        self.pw = nn.Conv1d(dim, dim, kernel_size=1, bias=False)

        self.pre_norm = nn.LayerNorm(dim)

        self.q_proj = nn.Linear(dim, dim, bias=False)
        self.k_proj = nn.Linear(dim, dim, bias=False)
        self.v_proj = nn.Linear(dim, dim, bias=False)
        self.out_proj = nn.Linear(dim, dim, bias=False)

        self.fuse = nn.Sequential(
            nn.Linear(dim, dim),
            nn.SiLU(),
            nn.Dropout(dropout),
            nn.Linear(dim, dim)
        )

        # Produces the gate between the TPA summary and the plain mean.
        self.conf_head = nn.Sequential(
            nn.Linear(dim, dim // 4),
            nn.ReLU(),
            nn.Dropout(dropout),
            nn.Linear(dim // 4, 1)
        )

        self.dropout = nn.Dropout(dropout)

    def forward(self, x, return_confidence=False):
        """Pool a sequence into one vector.

        Args:
            x: tensor [B, T, D] with D == self.dim.
            return_confidence: if True, also return the gate values.

        Returns:
            z of shape [B, D], or (z, confidence) with confidence in (0, 1)
            of shape [B, 1].
        """
        B, T, D = x.shape
        P = self.num_prototypes

        # Conv1d expects [B, D, T]: transpose once, run the conv stack, transpose back.
        h = self.pw(self.dw(self.lowpass(x.transpose(1, 2))))
        xloc = self.pre_norm(h.transpose(1, 2)) + x

        K = self.k_proj(xloc)
        V = self.v_proj(xloc)

        # The same prototype queries are shared across the batch.
        Qp = self.q_proj(self.proto).unsqueeze(0).expand(B, -1, -1)

        def split_heads(t, length):
            # [B, L, D] -> [B, heads, L, head_dim]
            return t.view(B, length, self.heads, self.head_dim).transpose(1, 2)

        Qh = split_heads(Qp, P)
        Kh = split_heads(K, T)
        Vh = split_heads(V, T)

        # Cosine-similarity attention sharpened by the temperature.
        Qh = F.normalize(Qh, dim=-1)
        Kh = F.normalize(Kh, dim=-1)

        scores = torch.matmul(Qh, Kh.transpose(-2, -1)) / self.temperature
        attn = F.softmax(scores, dim=-1)
        attn = torch.nan_to_num(attn, nan=0.0)
        attn = self.dropout(attn)

        proto_tokens = torch.matmul(attn, Vh)
        proto_tokens = proto_tokens.transpose(1, 2).contiguous().view(B, P, D)

        # Top-k over the prototype axis, taken independently for each feature.
        topk = max(1, int(P * self.topk_ratio))
        vals, _ = torch.topk(proto_tokens, k=topk, dim=1)
        z_tpa = vals.mean(dim=1)

        z_tpa = self.fuse(z_tpa)
        z_tpa = self.out_proj(z_tpa)

        z_gap = x.mean(dim=1)

        confidence = torch.sigmoid(self.conf_head(z_tpa))
        z = confidence * z_tpa + (1 - confidence) * z_gap

        if return_confidence:
            return z, confidence
        return z


class SimpleTPA(nn.Module):
    """Temporal Prototype Attention without the conv front-end.

    Identical to the full variant's attention / top-k / gating pipeline, but
    keys and values are projected straight from LayerNorm(x): there is no
    lowpass, depthwise or pointwise convolution and no residual around the norm.
    """

    def __init__(self, dim, num_prototypes=16, heads=4, dropout=0.1,
                 temperature=0.07, topk_ratio=0.25):
        super().__init__()
        assert dim % heads == 0

        self.dim = dim
        self.heads = heads
        self.head_dim = dim // heads
        self.num_prototypes = num_prototypes
        self.temperature = temperature
        self.topk_ratio = topk_ratio

        # Learnable query prototypes, small random init.
        self.proto = nn.Parameter(torch.randn(num_prototypes, dim) * 0.02)

        self.pre_norm = nn.LayerNorm(dim)

        self.q_proj = nn.Linear(dim, dim, bias=False)
        self.k_proj = nn.Linear(dim, dim, bias=False)
        self.v_proj = nn.Linear(dim, dim, bias=False)
        self.out_proj = nn.Linear(dim, dim, bias=False)

        self.fuse = nn.Sequential(
            nn.Linear(dim, dim),
            nn.SiLU(),
            nn.Dropout(dropout),
            nn.Linear(dim, dim)
        )

        # Gate between the attention summary and the temporal mean.
        self.conf_head = nn.Sequential(
            nn.Linear(dim, dim // 4),
            nn.ReLU(),
            nn.Dropout(dropout),
            nn.Linear(dim // 4, 1)
        )

        self.dropout = nn.Dropout(dropout)

    def forward(self, x, return_confidence=False):
        """Pool x ([B, T, D]) into [B, D]; optionally also return the [B, 1] gate."""
        batch, steps, width = x.shape
        n_proto = self.num_prototypes
        k_top = max(1, int(n_proto * self.topk_ratio))

        def to_heads(t, length):
            # [B, L, D] -> [B, heads, L, head_dim]
            return t.view(batch, length, self.heads, self.head_dim).transpose(1, 2)

        normed = self.pre_norm(x)
        shared_queries = self.q_proj(self.proto).unsqueeze(0).expand(batch, -1, -1)

        queries = F.normalize(to_heads(shared_queries, n_proto), dim=-1)
        keys = F.normalize(to_heads(self.k_proj(normed), steps), dim=-1)
        values = to_heads(self.v_proj(normed), steps)

        weights = F.softmax(torch.matmul(queries, keys.transpose(-2, -1)) / self.temperature, dim=-1)
        weights = self.dropout(torch.nan_to_num(weights, nan=0.0))

        tokens = torch.matmul(weights, values).transpose(1, 2).contiguous().view(batch, n_proto, width)
        summary = torch.topk(tokens, k=k_top, dim=1).values.mean(dim=1)
        summary = self.out_proj(self.fuse(summary))

        gate = torch.sigmoid(self.conf_head(summary))
        pooled = gate * summary + (1 - gate) * x.mean(dim=1)

        return (pooled, gate) if return_confidence else pooled

# ========================
# 4) Model Definitions
# ========================
class ConvBNAct(nn.Module):
    """Conv1d (no bias) -> BatchNorm1d -> GELU.

    Args:
        c_in, c_out: input / output channels.
        k: kernel size.
        s: stride.
        p: padding; defaults to k // 2 (length-preserving for odd k at stride 1).
        g: groups (g == c_in gives a depthwise conv).
    """

    def __init__(self, c_in, c_out, k, s=1, p=None, g=1):
        super().__init__()
        self.c = nn.Conv1d(c_in, c_out, k, s, k//2 if p is None else p, groups=g, bias=False)
        self.bn = nn.BatchNorm1d(c_out)
        self.act = nn.GELU()

    def forward(self, x):
        return self.act(self.bn(self.c(x)))

class MultiPathCNN(nn.Module):
    """Multi-scale 1-D CNN backbone.

    A 1x1 stem maps `in_ch` channels to d_model // 2. Parallel branches then
    each apply a strided depthwise conv (one kernel size per entry of
    `branches`) and a 1x1 conv. Branch outputs are concatenated along the
    channel axis and fused back to `d_model` channels by a final 1x1 conv.
    Input [B, in_ch, T] -> output [B, d_model, T'] with T' set by `stride`.
    """

    def __init__(self, in_ch=9, d_model=128, branches=(3,5,9,15), stride=2):
        super().__init__()
        h = d_model // 2
        self.pre = ConvBNAct(in_ch, h, 1)
        self.branches = nn.ModuleList([nn.Sequential(ConvBNAct(h, h, k, stride, g=h), ConvBNAct(h, h, 1)) for k in branches])
        self.post = ConvBNAct(len(branches)*h, d_model, 1)
        self.stride = stride

    def forward(self, x):
        # Compute the shared stem once. Re-running it per branch wasted compute
        # and, in training mode, updated its BatchNorm running statistics once
        # per branch instead of once per batch.
        stem = self.pre(x)
        return self.post(torch.cat([branch(stem) for branch in self.branches], dim=1))

class SimpleGAPHead(nn.Module):
    """Baseline head: average features over time, then a linear classifier."""

    def __init__(self, d_model: int, num_classes: int):
        super().__init__()
        self.fc = nn.Linear(d_model, num_classes)

    def forward(self, Fmap):
        """Map a channels-first feature map [B, D, T] to logits [B, num_classes]."""
        time_major = Fmap.transpose(1, 2)  # [B, T, D]
        return self.fc(time_major.mean(dim=1))

class TPAHead(nn.Module):
    """Classifier head that pools backbone features with Temporal Prototype Attention.

    `use_simple` selects SimpleTPA (no conv front-end; `seg_kernel` is then
    unused) instead of ProductionTPA. The pooled [B, D] vector feeds a linear
    classifier.
    """

    def __init__(self, d_model: int, num_classes: int,
                 num_prototypes: int = 16, seg_kernel: int = 9,
                 heads: int = 4, dropout: float = 0.1,
                 temperature: float = 0.07, topk_ratio: float = 0.25,
                 use_simple: bool = False):
        super().__init__()

        common = dict(
            dim=d_model,
            num_prototypes=num_prototypes,
            heads=heads,
            dropout=dropout,
            temperature=temperature,
            topk_ratio=topk_ratio,
        )
        # seg_kernel only configures the full variant's depthwise conv.
        self.tpa = SimpleTPA(**common) if use_simple else ProductionTPA(seg_kernel=seg_kernel, **common)

        self.classifier = nn.Linear(d_model, num_classes)

    def forward(self, Fmap):
        """Map a feature map [B, D, T] to logits [B, num_classes]."""
        pooled = self.tpa(Fmap.transpose(1, 2), return_confidence=False)
        return self.classifier(pooled)

class HAR_Model(nn.Module):
    """MultiPathCNN backbone followed by a GAP or TPA classification head.

    Args:
        d_model: backbone output width (also the head's input width).
        num_classes: number of output logits.
        model_type: 'gap' builds SimpleGAPHead; any other value builds a TPAHead.
        tpa_config: optional dict of TPA hyperparameters (num_prototypes,
            seg_kernel, heads, dropout, temperature, topk_ratio). Missing keys
            and None fall back to the defaults below. Ignored for 'gap'.
        use_simple_tpa: use SimpleTPA instead of ProductionTPA in the TPA head.
        in_ch: number of input channels given to the backbone (defaults to 9,
            the backbone's previous hard-coded value).
    """

    def __init__(self, d_model=128, num_classes=6, model_type='gap', tpa_config=None,
                 use_simple_tpa=False, in_ch=9):
        super().__init__()
        self.backbone = MultiPathCNN(in_ch=in_ch, d_model=d_model)
        self.model_type = model_type

        if model_type == 'gap':
            self.head = SimpleGAPHead(d_model=d_model, num_classes=num_classes)
        else:  # tpa
            # tpa_config defaults to None; treat that as "all defaults" rather
            # than failing on None.get(...).
            tpa_config = tpa_config or {}
            self.head = TPAHead(
                d_model=d_model,
                num_classes=num_classes,
                num_prototypes=tpa_config.get('num_prototypes', 16),
                seg_kernel=tpa_config.get('seg_kernel', 9),
                heads=tpa_config.get('heads', 4),
                dropout=tpa_config.get('dropout', 0.1),
                temperature=tpa_config.get('temperature', 0.07),
                topk_ratio=tpa_config.get('topk_ratio', 0.25),
                use_simple=use_simple_tpa
            )

    def forward(self, x):
        """Map an input batch to class logits via backbone feature map -> head."""
        fmap = self.backbone(x)
        return self.head(fmap)

# ========================
# 5) Train / Eval
# ========================
def train_one_epoch(model, loader, opt, cfg: "Config"):
    """Run one optimization pass over `loader`.

    Args:
        model: module mapping an input batch to class logits.
        loader: iterable of (x, y) batches.
        opt: optimizer over the model's parameters.
        cfg: reads `device`, `label_smoothing` and `grad_clip`.

    Returns:
        dict with 'loss' (mean label-smoothed cross-entropy per sample) and
        'acc' (training accuracy), computed over the batches that were actually
        stepped. Both are 0 if every batch was skipped.
    """
    model.train()
    total, correct, loss_sum = 0, 0, 0.0

    for x, y in loader:
        x = x.to(cfg.device).float()
        y = y.to(cfg.device)

        opt.zero_grad(set_to_none=True)
        logits = model(x)

        loss = F.cross_entropy(logits, y, label_smoothing=cfg.label_smoothing)

        # Skip NaN *and* +/-inf losses. The old NaN-only check let an infinite
        # loss through to backward/step and into the running average.
        if not torch.isfinite(loss):
            print("  Warning: non-finite loss detected, skipping batch")
            continue

        loss.backward()
        nn.utils.clip_grad_norm_(model.parameters(), cfg.grad_clip)
        opt.step()

        with torch.no_grad():
            pred = logits.argmax(dim=-1)
            correct += (pred == y).sum().item()
            total += y.size(0)
            loss_sum += loss.item() * y.size(0)

    return {
        "loss": loss_sum / total if total > 0 else 0,
        "acc": correct / total if total > 0 else 0,
    }

@torch.no_grad()
def evaluate(model, loader, cfg: Config):
    """Return (accuracy, macro-F1) of `model` over every batch in `loader`.

    Puts the model in eval mode; gradients are disabled by the decorator.
    Inputs are moved to `cfg.device`; predictions are argmax over the last
    logit dimension.
    """
    model.eval()
    truths, preds = [], []

    for inputs, targets in loader:
        logits = model(inputs.to(cfg.device))
        preds.append(logits.argmax(dim=-1).cpu().numpy())
        truths.append(targets.cpu().numpy())

    y_true = np.concatenate(truths)
    y_pred = np.concatenate(preds)
    return accuracy_score(y_true, y_pred), f1_score(y_true, y_pred, average='macro')

# ========================
# 6) Results Display Functions
# ========================

def print_model_results(model_name, results_list, clean_acc):
    """Print one model's per-scenario accuracy, F1, drop and retention.

    Args:
        model_name: title shown in the banner.
        results_list: dicts with 'noise_type', 'accuracy' (percent) and 'f1'.
        clean_acc: reference accuracy (percent). Drop = clean_acc - acc and
            retention = acc / clean_acc * 100 for every non-'Clean' row.

    When at least one non-'Clean' row exists, an average line over those rows
    is appended.
    """
    rule = '=' * 100
    print(f"\n{rule}")
    print(f"   {model_name} - DETAILED RESULTS")
    print(f"{rule}\n")

    print(f"{'Noise Type':<25} | {'Accuracy':>10} | {'F1-Score':>10} | {'Drop':>10} | {'Retention':>12}")
    print("-" * 100)

    drops, retentions = [], []
    for row in results_list:
        name, acc, f1 = row['noise_type'], row['accuracy'], row['f1']

        if name == 'Clean':
            drop, retention = 0.0, 100.0
        else:
            drop = clean_acc - acc
            retention = (acc / clean_acc) * 100
            drops.append(drop)
            retentions.append(retention)

        print(f"{name:<25} | {acc:9.2f}% | {f1:9.3f} | {drop:9.2f}% | {retention:11.2f}%")

    if drops:
        avg_drop = sum(drops) / len(drops)
        avg_retention = sum(retentions) / len(retentions)
        print("-" * 100)
        print(f"{'AVERAGE (excl. Clean)':<25} | {'-':>10} | {'-':>10} | {avg_drop:9.2f}% | {avg_retention:11.2f}%")

    print()

def print_comparison_table(gap_results, tpa_results, tpa_simple_results):
    """Print side-by-side robustness tables for the GAP, TPA and TPA-Simple models.

    Each argument is a list of dicts with 'noise_type' and 'accuracy'
    (percent). The lists are walked in lockstep with zip, so they should list
    the scenarios in the same order, and each must contain a 'Clean' entry
    (next() raises StopIteration otherwise). That entry is the reference for
    drop (clean - acc) and retention.

    Prints three sections:
    - per-scenario comparison with the lowest-drop model;
    - summary statistics (clean accuracy, average drop, average retention,
      robustness score = 0.3 * clean + 0.7 * retention);
    - improvement of each TPA variant over GAP.
    """
    print(f"\n{'='*150}")
    print("   COMPARATIVE ANALYSIS: GAP vs TPA vs TPA-Simple")
    print(f"{'='*150}\n")

    # Clean accuracy of each model: the reference for all drop/retention figures.
    gap_clean = next(r['accuracy'] for r in gap_results if r['noise_type'] == 'Clean')
    tpa_clean = next(r['accuracy'] for r in tpa_results if r['noise_type'] == 'Clean')
    tpa_simple_clean = next(r['accuracy'] for r in tpa_simple_results if r['noise_type'] == 'Clean')

    print(f"{'Noise Type':<25} | {'GAP Acc':>9} | {'GAP Drop':>9} | {'TPA Acc':>9} | {'TPA Drop':>9} | "
          f"{'Simple Acc':>9} | {'Simple Drop':>9} | {'Best Model':<12}")
    print("-" * 150)

    gap_total_drop = 0
    tpa_total_drop = 0
    tpa_simple_total_drop = 0
    noise_count = 0

    for gap_r, tpa_r, tpa_simple_r in zip(gap_results, tpa_results, tpa_simple_results):
        noise_type = gap_r['noise_type']
        gap_acc = gap_r['accuracy']
        tpa_acc = tpa_r['accuracy']
        tpa_simple_acc = tpa_simple_r['accuracy']

        if noise_type == 'Clean':
            gap_drop = 0.0
            tpa_drop = 0.0
            tpa_simple_drop = 0.0
            best_model = "-"
        else:
            gap_drop = gap_clean - gap_acc
            tpa_drop = tpa_clean - tpa_acc
            tpa_simple_drop = tpa_simple_clean - tpa_simple_acc

            gap_total_drop += gap_drop
            tpa_total_drop += tpa_drop
            tpa_simple_total_drop += tpa_simple_drop
            noise_count += 1

            # Best = lowest drop; ties go to the first listed model.
            drops = {'GAP': gap_drop, 'TPA': tpa_drop, 'TPA-Simple': tpa_simple_drop}
            best_model = min(drops, key=drops.get)

        print(f"{noise_type:<25} | {gap_acc:8.2f}% | {gap_drop:8.2f}% | {tpa_acc:8.2f}% | {tpa_drop:8.2f}% | "
              f"{tpa_simple_acc:8.2f}% | {tpa_simple_drop:8.2f}% | {best_model:<12}")

    # Default to 0.0 so the summary still works when every scenario is
    # 'Clean'. Previously these names were only bound inside the `if` below,
    # raising UnboundLocalError in that case.
    gap_avg_drop = tpa_avg_drop = tpa_simple_avg_drop = 0.0
    if noise_count > 0:
        gap_avg_drop = gap_total_drop / noise_count
        tpa_avg_drop = tpa_total_drop / noise_count
        tpa_simple_avg_drop = tpa_simple_total_drop / noise_count

        print("-" * 150)
        print(f"{'AVERAGE (excl. Clean)':<25} | {'-':>9} | {gap_avg_drop:8.2f}% | {'-':>9} | {tpa_avg_drop:8.2f}% | "
              f"{'-':>9} | {tpa_simple_avg_drop:8.2f}% | {'-':<12}")

    print()

    # Summary statistics
    print(f"\n{'='*150}")
    print("   SUMMARY STATISTICS")
    print(f"{'='*150}\n")

    # Mean noisy accuracy relative to clean accuracy, in percent.
    gap_avg_retention = ((gap_clean - gap_avg_drop) / gap_clean) * 100
    tpa_avg_retention = ((tpa_clean - tpa_avg_drop) / tpa_clean) * 100
    tpa_simple_avg_retention = ((tpa_simple_clean - tpa_simple_avg_drop) / tpa_simple_clean) * 100

    print(f"{'Metric':<30} | {'GAP Baseline':>15} | {'TPA (Full)':>15} | {'TPA-Simple':>15} | {'Best Model':<15}")
    print("-" * 150)

    # Clean Accuracy
    clean_best = max([('GAP', gap_clean), ('TPA', tpa_clean), ('TPA-Simple', tpa_simple_clean)], key=lambda x: x[1])[0]
    print(f"{'Clean Accuracy':<30} | {gap_clean:14.2f}% | {tpa_clean:14.2f}% | {tpa_simple_clean:14.2f}% | {clean_best:<15}")

    # Average Drop
    drop_best = min([('GAP', gap_avg_drop), ('TPA', tpa_avg_drop), ('TPA-Simple', tpa_simple_avg_drop)], key=lambda x: x[1])[0]
    print(f"{'Average Drop':<30} | {gap_avg_drop:14.2f}% | {tpa_avg_drop:14.2f}% | {tpa_simple_avg_drop:14.2f}% | {drop_best:<15}")

    # Average Retention
    retention_best = max([('GAP', gap_avg_retention), ('TPA', tpa_avg_retention), ('TPA-Simple', tpa_simple_avg_retention)], key=lambda x: x[1])[0]
    print(f"{'Average Retention':<30} | {gap_avg_retention:14.2f}% | {tpa_avg_retention:14.2f}% | {tpa_simple_avg_retention:14.2f}% | {retention_best:<15}")

    # Robustness score: weighted blend favouring retention over clean accuracy.
    gap_robustness = gap_clean * 0.3 + gap_avg_retention * 0.7
    tpa_robustness = tpa_clean * 0.3 + tpa_avg_retention * 0.7
    tpa_simple_robustness = tpa_simple_clean * 0.3 + tpa_simple_avg_retention * 0.7

    robustness_best = max([('GAP', gap_robustness), ('TPA', tpa_robustness), ('TPA-Simple', tpa_simple_robustness)], key=lambda x: x[1])[0]
    print(f"{'Robustness Score':<30} | {gap_robustness:14.2f} | {tpa_robustness:14.2f} | {tpa_simple_robustness:14.2f} | {robustness_best:<15}")

    print()

    # Improvement analysis
    print(f"\n{'='*150}")
    print("   IMPROVEMENT OVER GAP BASELINE")
    print(f"{'='*150}\n")

    print(f"{'Metric':<30} | {'TPA (Full)':>20} | {'TPA-Simple':>20}")
    print("-" * 150)
    print(f"{'Clean Accuracy':<30} | {tpa_clean - gap_clean:+19.2f}% | {tpa_simple_clean - gap_clean:+19.2f}%")
    print(f"{'Average Drop Reduction':<30} | {gap_avg_drop - tpa_avg_drop:+19.2f}% | {gap_avg_drop - tpa_simple_avg_drop:+19.2f}%")
    print(f"{'Average Retention Gain':<30} | {tpa_avg_retention - gap_avg_retention:+19.2f}% | {tpa_simple_avg_retention - gap_avg_retention:+19.2f}%")
    print(f"{'Robustness Score Gain':<30} | {tpa_robustness - gap_robustness:+19.2f} | {tpa_simple_robustness - gap_robustness:+19.2f}")
    print()

# ========================
# 7) Noise Robustness Study
# ========================
def run_noise_robustness_study(cfg: Config):
    os.makedirs(cfg.save_dir, exist_ok=True)

    # Load data
    print(f"\n{'='*70}")
    print("   DATA PREPARATION")
    print(f"{'='*70}")

    X_full, y_full = _load_split_raw(cfg.data_dir, "train")
    mean = X_full.mean(axis=(0,2), keepdims=True)
    std = X_full.std(axis=(0,2), keepdims=True) + 1e-6
    X_full = ((X_full - mean) / std).astype(np.float32)

    indices = np.arange(len(X_full))
    train_indices, val_indices = train_test_split(
        indices,
        test_size=cfg.val_split,
        random_state=SEED,
        stratify=y_full
    )

    print(f"Total samples: {len(X_full)}")
    print(f"  → Train: {len(train_indices)} ({(1-cfg.val_split)*100:.0f}%)")
    print(f"  → Val:   {len(val_indices)} ({cfg.val_split*100:.0f}%)")

    # Create datasets
    train_set = UCIHARInertial(
        cfg.data_dir, "train",
        mean=mean, std=std,
        preloaded_data=(X_full, y_full),
        indices=train_indices,
        augment=True,
        max_noise=cfg.max_train_noise,
        augment_prob=cfg.train_augment_prob
    )

    val_set = UCIHARInertial(
        cfg.data_dir, "train",
        mean=mean, std=std,
        preloaded_data=(X_full, y_full),
        indices=val_indices,
        augment=False
    )

    test_set_orig = UCIHARInertial(cfg.data_dir, "test", mean=mean, std=std, augment=False)

    val_loader = DataLoader(val_set, cfg.batch_size, num_workers=cfg.num_workers)

    # Define noise configurations for testing
    noise_configs = [
        {'name': 'Clean', 'type': 'none', 'level': 0.0},
        {'name': '20% Gaussian', 'type': 'gaussian', 'level': 0.2},
        {'name': '40% Gaussian', 'type': 'gaussian', 'level': 0.4},
        {'name': '30% Temporal Mask', 'type': 'temporal', 'level': 0.3},
        {'name': '20% Channel Drift', 'type': 'drift', 'level': 0.2},
        {'name': '20% Channel Drop', 'type': 'drop', 'level': 0.2},
        {'name': '5% Spike Noise', 'type': 'spike', 'level': 0.05},
    ]

    # Define 3 model configurations
    model_configs = [
        {"name": "GAP_Baseline", "model_type": "gap", "use_simple": False},
        {"name": "TPA", "model_type": "tpa", "use_simple": False},
        {"name": "TPA_Simple", "model_type": "tpa", "use_simple": True},
    ]

    print(f"\n{'='*70}")
    print("    NOISE ROBUSTNESS STUDY: 3-Model Comparison")
    print(f"{'='*70}")
    print(f"   Models: GAP (Baseline), TPA (Full), TPA-Simple (No Conv)")
    print(f"   Training: Mixed noise augmentation (prob={cfg.train_augment_prob}, max={cfg.max_train_noise})")
    print(f"   Testing: 7 noise scenarios")
    print(f"{'='*70}\n")

    # Store results for each model
    gap_results = []
    tpa_results = []
    tpa_simple_results = []

    # Train and evaluate each model
    for model_cfg in model_configs:
        print(f"\n{'='*70}")
        print(f"   MODEL: {model_cfg['name']}")
        print(f"{'='*70}")

        random.seed(SEED); np.random.seed(SEED); torch.manual_seed(SEED); torch.cuda.manual_seed_all(SEED)

        tpa_config = {
            'num_prototypes': cfg.tpa_num_prototypes,
            'seg_kernel': cfg.tpa_seg_kernel,
            'heads': cfg.tpa_heads,
            'dropout': cfg.tpa_dropout,
            'temperature': cfg.tpa_temperature,
            'topk_ratio': cfg.tpa_topk_ratio
        }

        model = HAR_Model(
            d_model=cfg.d_model,
            model_type=model_cfg["model_type"],
            tpa_config=tpa_config,
            use_simple_tpa=model_cfg.get("use_simple", False)
        ).to(cfg.device).float()

        g = torch.Generator(device='cpu').manual_seed(SEED)
        train_loader = DataLoader(train_set, cfg.batch_size, shuffle=True,
                                 num_workers=cfg.num_workers, generator=g)

        opt = torch.optim.AdamW(model.parameters(), lr=cfg.lr, weight_decay=cfg.weight_decay)

        best_acc, best_wts = 0.0, None
        patience_counter = 0
        best_epoch = 0

        print(f"\nTraining {model_cfg['name']} for up to {cfg.epochs} epochs (patience={cfg.patience})...")

        for epoch in range(1, cfg.epochs + 1):
            stats = train_one_epoch(model, train_loader, opt, cfg)
            val_acc, val_f1 = evaluate(model, val_loader, cfg)

            improved = False
            if val_acc > best_acc + cfg.min_delta:
                best_acc = val_acc
                best_wts = copy.deepcopy(model.state_dict())
                patience_counter = 0
                best_epoch = epoch
                improved = True
            else:
                patience_counter += 1

            log_str = f"[{epoch:02d}/{cfg.epochs}] Train L:{stats['loss']:.4f} A:{stats['acc']:.4f}"
            log_str += f" | Val A:{val_acc:.4f} F1:{val_f1:.4f}"
            if improved:
                log_str += " ✓"

            if epoch % 10 == 0 or epoch == 1:
                print(log_str)

            if patience_counter >= cfg.patience:
                print(f"\n⚠ Early stopping triggered at epoch {epoch}")
                print(f"  Best validation acc: {best_acc:.4f} (epoch {best_epoch})")
                break

        if best_wts:
            model_path = os.path.join(cfg.save_dir, f"model_{model_cfg['name']}.pth")
            torch.save(best_wts, model_path)
            model.load_state_dict(best_wts)
            print(f"\n✓ Best Val Acc: {best_acc:.4f} (epoch {best_epoch})")

        # Test on all noise scenarios
        print(f"\n   Testing on {len(noise_configs)} noise scenarios...")

        model_results = []

        for noise_cfg in noise_configs:
            X_test = test_set_orig.X.copy()

            # Apply noise
            if noise_cfg['type'] == 'gaussian':
                X_test = add_gaussian_noise(X_test, noise_cfg['level'])
            elif noise_cfg['type'] == 'temporal':
                X_test = add_temporal_mask(X_test, noise_cfg['level'])
            elif noise_cfg['type'] == 'drift':
                X_test = add_channel_drift(X_test, noise_cfg['level'])
            elif noise_cfg['type'] == 'drop':
                X_test = add_channel_drop(X_test, noise_cfg['level'])
            elif noise_cfg['type'] == 'spike':
                X_test = add_spike_noise(X_test, noise_cfg['level'])

            # Create test dataset
            test_ds_noisy = UCIHARInertial(
                cfg.data_dir, "test",
                mean=mean, std=std,
                preloaded_data=(X_test, test_set_orig.y),
                augment=False
            )
            test_loader_noisy = DataLoader(test_ds_noisy, cfg.batch_size,
                                          num_workers=cfg.num_workers)

            acc, f1 = evaluate(model, test_loader_noisy, cfg)

            model_results.append({
                'noise_type': noise_cfg['name'],
                'accuracy': acc * 100,
                'f1': f1
            })

        # Store results
        if model_cfg['name'] == 'GAP_Baseline':
            gap_results = model_results
        elif model_cfg['name'] == 'TPA':
            tpa_results = model_results
        else:
            tpa_simple_results = model_results

    # ========================
    # Display Results
    # ========================

    # Get clean accuracies
    gap_clean = next(r['accuracy'] for r in gap_results if r['noise_type'] == 'Clean')
    tpa_clean = next(r['accuracy'] for r in tpa_results if r['noise_type'] == 'Clean')
    tpa_simple_clean = next(r['accuracy'] for r in tpa_simple_results if r['noise_type'] == 'Clean')

    # Print individual model results
    print_model_results("GAP BASELINE", gap_results, gap_clean)
    print_model_results("TPA (Full)", tpa_results, tpa_clean)
    print_model_results("TPA-Simple (No Conv)", tpa_simple_results, tpa_simple_clean)

    # Print comparison table
    print_comparison_table(gap_results, tpa_results, tpa_simple_results)

    # Save results to CSV
    all_results = []
    for gap_r, tpa_r, tpa_simple_r in zip(gap_results, tpa_results, tpa_simple_results):
        noise_type = gap_r['noise_type']

        # GAP metrics
        gap_acc = gap_r['accuracy']
        gap_f1 = gap_r['f1']
        gap_drop = 0.0 if noise_type == 'Clean' else gap_clean - gap_acc
        gap_retention = 100.0 if noise_type == 'Clean' else (gap_acc / gap_clean) * 100

        # TPA metrics
        tpa_acc = tpa_r['accuracy']
        tpa_f1 = tpa_r['f1']
        tpa_drop = 0.0 if noise_type == 'Clean' else tpa_clean - tpa_acc
        tpa_retention = 100.0 if noise_type == 'Clean' else (tpa_acc / tpa_clean) * 100

        # TPA-Simple metrics
        tpa_simple_acc = tpa_simple_r['accuracy']
        tpa_simple_f1 = tpa_simple_r['f1']
        tpa_simple_drop = 0.0 if noise_type == 'Clean' else tpa_simple_clean - tpa_simple_acc
        tpa_simple_retention = 100.0 if noise_type == 'Clean' else (tpa_simple_acc / tpa_simple_clean) * 100

        all_results.append({
            'Noise Type': noise_type,
            'GAP Accuracy (%)': gap_acc,
            'GAP F1-Score': gap_f1,
            'GAP Drop (%)': gap_drop,
            'GAP Retention (%)': gap_retention,
            'TPA Accuracy (%)': tpa_acc,
            'TPA F1-Score': tpa_f1,
            'TPA Drop (%)': tpa_drop,
            'TPA Retention (%)': tpa_retention,
            'TPA-Simple Accuracy (%)': tpa_simple_acc,
            'TPA-Simple F1-Score': tpa_simple_f1,
            'TPA-Simple Drop (%)': tpa_simple_drop,
            'TPA-Simple Retention (%)': tpa_simple_retention
        })

    df_results = pd.DataFrame(all_results)
    csv_path = os.path.join(cfg.save_dir, 'three_models_comparison.csv')
    df_results.to_csv(csv_path, index=False)
    print(f"✓ Results saved to: {csv_path}\n")

    return df_results

# ========================
# 8) Main Execution
# ========================
if __name__ == "__main__":
    # Every hyperparameter for this run is passed explicitly, even where it
    # equals the Config default. The run's settings stay visible in one place
    # and are unaffected if the dataclass defaults change later.
    config = Config(
        epochs=100,
        lr=1e-4,
        # Training augmentation settings
        train_augment_prob=0.7,
        max_train_noise=0.4,
        # Early stopping
        patience=20,
        min_delta=0.0001,
        val_split=0.2,
        # TPA hyperparameters
        tpa_num_prototypes=16,
        tpa_seg_kernel=9,
        tpa_heads=4,
        tpa_dropout=0.1,
        tpa_temperature=0.07,
        tpa_topk_ratio=0.25,
    )

    rule = "=" * 70

    # Run summary printed before training starts. The noise-scenario list is
    # descriptive text only; it is not derived from the scenarios the study
    # actually evaluates, so keep it in sync by hand.
    banner = [
        "\n" + rule,
        "    NOISE ROBUSTNESS STUDY: 3-Model Comparison",
        rule,
        f"Device: {config.device}",
        f"Epochs: {config.epochs}",
        f"Learning Rate: {config.lr}",
        f"Training Augmentation: prob={config.train_augment_prob}, max_noise={config.max_train_noise}",
        "\n3 Models to Compare:",
        "  1) GAP Baseline:  Standard Global Average Pooling",
        "  2) TPA (Full):    Temporal Prototype Attention with Conv layers",
        "  3) TPA-Simple:    TPA without lowpass/depthwise/pointwise conv",
        "\n7 Noise Scenarios:",
        "  - Clean (no noise)",
        "  - Gaussian noise (20%, 40%)",
        "  - Temporal mask (30%)",
        "  - Channel drift (20%)",
        "  - Channel drop (20%)",
        "  - Spike noise (5%)",
        "\nTPA Configuration:",
        f"  Prototypes: {config.tpa_num_prototypes}",
        f"  Heads: {config.tpa_heads}",
        f"  Temperature: {config.tpa_temperature}",
        f"  TopK Ratio: {config.tpa_topk_ratio}",
        rule + "\n",
    ]
    print("\n".join(banner))

    # Train, evaluate and compare the three models. The results DataFrame is
    # kept at module scope so later notebook cells can inspect it.
    df_results = run_noise_robustness_study(config)

    # Closing summary of the artifacts the run reports having produced.
    summary = [
        "\n" + rule,
        "STUDY COMPLETED!",
        rule,
        "\nGenerated files:",
        f"  1. {os.path.join(config.save_dir, 'three_models_comparison.csv')}",
        "  2. model_GAP_Baseline.pth",
        "  3. model_TPA.pth",
        "  4. model_TPA_Simple.pth",
        rule + "\n",
    ]
    print("\n".join(summary))


    NOISE ROBUSTNESS STUDY: 3-Model Comparison
Device: cuda
Epochs: 100
Learning Rate: 0.0001
Training Augmentation: prob=0.7, max_noise=0.4

3 Models to Compare:
  1) GAP Baseline:  Standard Global Average Pooling
  2) TPA (Full):    Temporal Prototype Attention with Conv layers
  3) TPA-Simple:    TPA without lowpass/depthwise/pointwise conv

7 Noise Scenarios:
  - Clean (no noise)
  - Gaussian noise (20%, 40%)
  - Temporal mask (30%)
  - Channel drift (20%)
  - Channel drop (20%)
  - Spike noise (5%)

TPA Configuration:
  Prototypes: 16
  Heads: 4
  Temperature: 0.07
  TopK Ratio: 0.25


   DATA PREPARATION
Total samples: 7352
  → Train: 5881 (80%)
  → Val:   1471 (20%)

    NOISE ROBUSTNESS STUDY: 3-Model Comparison
   Models: GAP (Baseline), TPA (Full), TPA-Simple (No Conv)
   Training: Mixed noise augmentation (prob=0.7, max=0.4)
   Testing: 7 noise scenarios


   MODEL: GAP_Baseline

Training GAP_Baseline for up to 100 epochs (patience=20)...
[01/100] Train L:1.4717 A:0.5215 | 