In [1]:
from google.colab import drive
drive.mount('/content/drive')

Mounted at /content/drive


In [5]:
# -*- coding: utf-8 -*-
"""
TPA Noise Robustness Study: 3-Model Comparison with Drop-based Analysis
Models: GAP (Baseline), TPA, TPA+Mask
Noise Types: Gaussian, Temporal Mask, Channel Drift, Channel Drop, Spike
"""

import os, random, math, sys, time, copy, json
import numpy as np
import pandas as pd
from typing import Tuple, List
from dataclasses import dataclass

import torch
import torch.nn as nn
import torch.nn.functional as F
from torch.utils.data import Dataset, DataLoader

from sklearn.metrics import accuracy_score, f1_score, confusion_matrix, classification_report
from sklearn.model_selection import train_test_split

# ========================
# 0) Config & Reproducibility
# ========================
SEED = 42
random.seed(SEED); np.random.seed(SEED); torch.manual_seed(SEED); torch.cuda.manual_seed_all(SEED)
torch.backends.cudnn.deterministic = True
torch.backends.cudnn.benchmark = False

@dataclass
class Config:
    data_dir: str = "/content/drive/MyDrive/AI_data/UCI_HAR_Dataset/UCI HAR Dataset"
    save_dir: str = "/content/drive/MyDrive/AI_data/noise_robustness_study"

    epochs: int = 100
    batch_size: int = 128
    lr: float = 1e-4
    weight_decay: float = 1e-4
    grad_clip: float = 1.0
    label_smoothing: float = 0.05

    # Early stopping
    patience: int = 20
    min_delta: float = 0.0001
    val_split: float = 0.2

    # Training augmentation (mixed noise)
    train_augment_prob: float = 0.7
    max_train_noise: float = 0.4

    d_model: int = 128

    # TPA hyperparameters
    tpa_num_prototypes: int = 16
    tpa_seg_kernel: int = 9
    tpa_heads: int = 4
    tpa_dropout: float = 0.1
    tpa_temperature: float = 0.07
    tpa_topk_ratio: float = 0.25

    device: str = "cuda" if torch.cuda.is_available() else "cpu"
    num_workers: int = 2

# ========================
# 1) Noise Augmentation Functions with Masks
# ========================

def add_gaussian_noise(X, noise_level):
    """Add Gaussian noise - mask all valid (noise affects entire signal)"""
    if noise_level == 0:
        return X, np.ones((X.shape[0], X.shape[2]), dtype=bool)

    noise = np.random.normal(0, noise_level, X.shape).astype(np.float32)
    mask = np.ones((X.shape[0], X.shape[2]), dtype=bool)  # [B, T]
    return X + noise, mask

def add_temporal_mask(X, mask_ratio=0.3):
    """Add temporal masking - masked parts are False"""
    X_aug = X.copy()
    B, C, T = X.shape
    mask = np.ones((B, T), dtype=bool)  # [B, T]

    for i in range(B):
        if np.random.rand() < 0.5:
            mask_len = int(T * mask_ratio)
            start = np.random.randint(0, T - mask_len)
            X_aug[i, :, start:start+mask_len] = 0
            mask[i, start:start+mask_len] = False  # Mark corrupted region

    return X_aug, mask

def add_channel_drift(X, drift_std=0.2):
    """Channel drift - mask all valid"""
    X_aug = X.copy()
    B, C, T = X.shape
    mask = np.ones((B, T), dtype=bool)

    for i in range(B):
        for c in range(C):
            if np.random.rand() < 0.3:
                drift = np.random.normal(0, drift_std)
                X_aug[i, c, :] += drift

    return X_aug, mask

def add_channel_drop(X, drop_prob=0.2):
    """Channel drop - time axis mask remains valid"""
    X_aug = X.copy()
    B, C, T = X.shape
    mask = np.ones((B, T), dtype=bool)

    for i in range(B):
        drop_mask = np.random.rand(C) < drop_prob
        X_aug[i, drop_mask, :] = 0

    return X_aug, mask

def add_spike_noise(X, spike_prob=0.05, spike_magnitude=2.0):
    """Spike noise - mark spike locations in mask"""
    X_aug = X.copy()
    B, C, T = X.shape

    spike_mask_3d = np.random.rand(*X.shape) < spike_prob  # [B, C, T]
    spikes = np.random.randn(*X.shape) * spike_magnitude
    X_aug[spike_mask_3d] += spikes[spike_mask_3d]

    # Reduce to time axis: if any channel has spike, mark as False
    mask = ~(spike_mask_3d.any(axis=1))  # [B, T]

    return X_aug, mask

def apply_mixed_augmentation(X, max_noise=0.4):
    """Apply random combination of augmentations with mask tracking"""
    aug_type = np.random.choice([
        'gaussian', 'temporal', 'drift', 'drop', 'spike', 'mixed'
    ])

    if aug_type == 'gaussian':
        noise_level = np.random.uniform(0, max_noise)
        return add_gaussian_noise(X, noise_level)

    elif aug_type == 'temporal':
        mask_ratio = np.random.uniform(0.1, 0.4)
        return add_temporal_mask(X, mask_ratio)

    elif aug_type == 'drift':
        drift_std = np.random.uniform(0.1, 0.3)
        return add_channel_drift(X, drift_std)

    elif aug_type == 'drop':
        drop_prob = np.random.uniform(0.1, 0.3)
        return add_channel_drop(X, drop_prob)

    elif aug_type == 'spike':
        spike_prob = np.random.uniform(0.02, 0.08)
        return add_spike_noise(X, spike_prob)

    else:  # mixed
        X_aug = X.copy()
        B, T = X.shape[0], X.shape[2]
        mask = np.ones((B, T), dtype=bool)

        num_augs = np.random.randint(2, 4)
        augs = np.random.choice(
            ['gaussian', 'temporal', 'drift', 'spike'],
            size=num_augs, replace=False
        )

        for aug in augs:
            if aug == 'gaussian':
                X_aug, m = add_gaussian_noise(X_aug, np.random.uniform(0, max_noise*0.5))
            elif aug == 'temporal':
                X_aug, m = add_temporal_mask(X_aug, np.random.uniform(0.1, 0.2))
            elif aug == 'drift':
                X_aug, m = add_channel_drift(X_aug, np.random.uniform(0.05, 0.15))
            elif aug == 'spike':
                X_aug, m = add_spike_noise(X_aug, np.random.uniform(0.02, 0.05))

            mask = mask & m  # AND operation: only valid in all augmentations

        return X_aug, mask

# ========================
# 2) UCI-HAR Data Loader
# ========================
_RAW_CHANNELS = [
    ("Inertial Signals/total_acc_x_", "txt"), ("Inertial Signals/total_acc_y_", "txt"),
    ("Inertial Signals/total_acc_z_", "txt"), ("Inertial Signals/body_acc_x_", "txt"),
    ("Inertial Signals/body_acc_y_", "txt"), ("Inertial Signals/body_acc_z_", "txt"),
    ("Inertial Signals/body_gyro_x_", "txt"), ("Inertial Signals/body_gyro_y_", "txt"),
    ("Inertial Signals/body_gyro_z_", "txt"),
]
_LABEL_MAP = {1:"WALKING", 2:"WALKING_UPSTAIRS", 3:"WALKING_DOWNSTAIRS", 4:"SITTING", 5:"STANDING", 6:"LAYING"}
_CODE_TO_LABEL_NAME = {i-1: _LABEL_MAP[i] for i in _LABEL_MAP}

def _load_split_raw(root: str, split: str) -> Tuple[np.ndarray, np.ndarray]:
    assert split in ("train", "test")
    X_list = [np.loadtxt(os.path.join(root, split, p + split + "." + e))[..., None] for p, e in _RAW_CHANNELS]
    X = np.concatenate(X_list, axis=-1).transpose(0, 2, 1)
    y = np.loadtxt(os.path.join(root, split, f"y_{split}.txt")).astype(int)
    return X, y

class UCIHARInertial(Dataset):
    def __init__(self, root: str, split: str, mean=None, std=None,
                 preloaded_data: Tuple[np.ndarray, np.ndarray] | None = None,
                 indices: np.ndarray | None = None,
                 augment: bool = False, max_noise: float = 0.4, augment_prob: float = 0.7):
        super().__init__()

        if preloaded_data is not None:
            X, y = preloaded_data
        else:
            X, y = _load_split_raw(root, split)

        if indices is not None:
            X = X[indices]
            y = y[indices]

        self.X = X.astype(np.float32)
        self.y = (y - 1).astype(np.int64) if y.min() >= 1 else y.astype(np.int64)

        if mean is not None and std is not None:
            self.mean, self.std = mean, std
            if preloaded_data is None:
                self.X = (self.X - self.mean) / self.std
        else:
            self.mean = self.X.mean(axis=(0,2), keepdims=True).astype(np.float32)
            self.std = (self.X.std(axis=(0,2), keepdims=True) + 1e-6).astype(np.float32)
            self.X = ((self.X - self.mean) / self.std).astype(np.float32)

        self.augment = augment
        self.max_noise = max_noise
        self.augment_prob = augment_prob

    def __len__(self):
        return self.X.shape[0]

    def __getitem__(self, idx):
        x = self.X[idx].copy()
        y = self.y[idx]

        # Generate mask
        T = x.shape[1]  # Time dimension
        mask = np.ones(T, dtype=bool)  # Default: all valid

        if self.augment and np.random.rand() < self.augment_prob:
            # Apply augmentation and get mask
            x_aug, mask_aug = apply_mixed_augmentation(
                x[np.newaxis], self.max_noise
            )
            x = x_aug[0]
            mask = mask_aug[0]  # [T]

        return (
            torch.from_numpy(x).float(),
            torch.tensor(y, dtype=torch.long),
            torch.from_numpy(mask).bool()
        )

class UCIHARInertialWithMask(Dataset):
    """Dataset with pre-computed masks for testing"""
    def __init__(self, X, y, masks):
        self.X = torch.FloatTensor(X)
        self.y = torch.LongTensor(y)
        self.masks = torch.BoolTensor(masks)

    def __len__(self):
        return len(self.X)

    def __getitem__(self, idx):
        return self.X[idx], self.y[idx], self.masks[idx]

# ========================
# 3) TPA Module
# ========================
class ProductionTPA(nn.Module):
    """Temporal Prototype Attention with optional mask support"""

    def __init__(self, dim, num_prototypes=16, seg_kernel=9, heads=4, dropout=0.1,
                 temperature=0.07, topk_ratio=0.25):
        super().__init__()
        assert dim % heads == 0

        self.dim = dim
        self.heads = heads
        self.head_dim = dim // heads
        self.num_prototypes = num_prototypes
        self.temperature = temperature
        self.topk_ratio = topk_ratio

        self.proto = nn.Parameter(torch.randn(num_prototypes, dim) * 0.02)

        pad = (seg_kernel - 1) // 2
        self.lowpass = nn.Conv1d(dim, dim, kernel_size=5, padding=2, groups=dim, bias=False)
        self.dw = nn.Conv1d(dim, dim, kernel_size=seg_kernel, padding=pad, groups=dim, bias=False)
        self.pw = nn.Conv1d(dim, dim, kernel_size=1, bias=False)

        self.pre_norm = nn.LayerNorm(dim)

        self.q_proj = nn.Linear(dim, dim, bias=False)
        self.k_proj = nn.Linear(dim, dim, bias=False)
        self.v_proj = nn.Linear(dim, dim, bias=False)
        self.out_proj = nn.Linear(dim, dim, bias=False)

        self.fuse = nn.Sequential(
            nn.Linear(dim, dim),
            nn.SiLU(),
            nn.Dropout(dropout),
            nn.Linear(dim, dim)
        )

        self.conf_head = nn.Sequential(
            nn.Linear(dim, dim // 4),
            nn.ReLU(),
            nn.Dropout(dropout),
            nn.Linear(dim // 4, 1)
        )

        self.dropout = nn.Dropout(dropout)

    def forward(self, x, mask=None, return_confidence=False, use_mask=True):
        """
        x: [B, T, D]
        mask: [B, T] - True for valid frames
        use_mask: whether to use mask (for ablation study)
        """
        B, T, D = x.shape
        P = self.num_prototypes

        x_filtered = self.lowpass(x.transpose(1, 2)).transpose(1, 2)
        xloc = self.pw(self.dw(x_filtered.transpose(1, 2))).transpose(1, 2)
        xloc = self.pre_norm(xloc) + x

        if mask is not None and use_mask:
            float_mask = mask.float()
            float_mask_expanded = float_mask.unsqueeze(-1)
            xloc = xloc * float_mask_expanded

        K = self.k_proj(xloc)
        V = self.v_proj(xloc)

        Qp = self.q_proj(self.proto).unsqueeze(0).expand(B, -1, -1)

        def split_heads(t, length):
            return t.view(B, length, self.heads, self.head_dim).transpose(1, 2)

        Qh = split_heads(Qp, P)
        Kh = split_heads(K, T)
        Vh = split_heads(V, T)

        Qh = F.normalize(Qh, dim=-1)
        Kh = F.normalize(Kh, dim=-1)

        scores = torch.matmul(Qh, Kh.transpose(-2, -1)) / self.temperature

        if mask is not None and use_mask:
            boolean_mask = mask.bool()
            mask_attn = boolean_mask.unsqueeze(1).unsqueeze(2)
            scores = scores.masked_fill(~mask_attn, float('-inf'))

        attn = F.softmax(scores, dim=-1)
        attn = torch.nan_to_num(attn, nan=0.0)
        attn = self.dropout(attn)

        proto_tokens = torch.matmul(attn, Vh)
        proto_tokens = proto_tokens.transpose(1, 2).contiguous().view(B, P, D)

        topk = max(1, int(P * self.topk_ratio))
        vals, _ = torch.topk(proto_tokens, k=topk, dim=1)
        z_tpa = vals.mean(dim=1)

        z_tpa = self.fuse(z_tpa)
        z_tpa = self.out_proj(z_tpa)

        if mask is not None and use_mask:
            mask_expanded = float_mask.unsqueeze(-1).float()
            z_gap = (x * mask_expanded).sum(dim=1) / (mask_expanded.sum(dim=1) + 1e-9)
        else:
            z_gap = x.mean(dim=1)

        confidence = torch.sigmoid(self.conf_head(z_tpa))
        z = confidence * z_tpa + (1 - confidence) * z_gap

        if return_confidence:
            return z, confidence
        return z

# ========================
# 4) Model Definitions
# ========================
class ConvBNAct(nn.Module):
    def __init__(self, c_in, c_out, k, s=1, p=None, g=1):
        super().__init__()
        self.c = nn.Conv1d(c_in, c_out, k, s, k//2 if p is None else p, groups=g, bias=False)
        self.bn = nn.BatchNorm1d(c_out)
        self.act = nn.GELU()

    def forward(self, x):
        return self.act(self.bn(self.c(x)))

class MultiPathCNN(nn.Module):
    def __init__(self, in_ch=9, d_model=128, branches=(3,5,9,15), stride=2):
        super().__init__()
        h = d_model // 2
        self.pre = ConvBNAct(in_ch, h, 1)
        self.branches = nn.ModuleList([nn.Sequential(ConvBNAct(h, h, k, stride, g=h), ConvBNAct(h, h, 1)) for k in branches])
        self.post = ConvBNAct(len(branches)*h, d_model, 1)
        self.stride = stride

    def forward(self, x):
        return self.post(torch.cat([b(self.pre(x)) for b in self.branches], dim=1))

class SimpleGAPHead(nn.Module):
    """Baseline: Global Average Pooling"""
    def __init__(self, d_model: int, num_classes: int):
        super().__init__()
        self.fc = nn.Linear(d_model, num_classes)

    def forward(self, Fmap):
        features = Fmap.transpose(1, 2)
        pooled = features.mean(dim=1)
        logits = self.fc(pooled)
        return logits

class TPAHead(nn.Module):
    """TPA: Temporal Prototype Attention (with optional mask support)"""
    def __init__(self, d_model: int, num_classes: int,
                 num_prototypes: int = 16, seg_kernel: int = 9,
                 heads: int = 4, dropout: float = 0.1,
                 temperature: float = 0.07, topk_ratio: float = 0.25,
                 use_mask: bool = False):
        super().__init__()
        self.use_mask = use_mask

        self.tpa = ProductionTPA(
            dim=d_model,
            num_prototypes=num_prototypes,
            seg_kernel=seg_kernel,
            heads=heads,
            dropout=dropout,
            temperature=temperature,
            topk_ratio=topk_ratio
        )

        self.classifier = nn.Linear(d_model, num_classes)

    def forward(self, Fmap, mask: torch.BoolTensor | None = None):
        """
        Fmap: [B, D, T]
        mask: [B, T]
        """
        features = Fmap.transpose(1, 2)
        z = self.tpa(features, mask=mask, return_confidence=False, use_mask=self.use_mask)
        logits = self.classifier(z)
        return logits

class HAR_Model(nn.Module):
    def __init__(self, d_model=128, num_classes=6, model_type='gap', use_mask=False, tpa_config=None):
        super().__init__()
        self.backbone = MultiPathCNN(d_model=d_model)
        self.model_type = model_type
        self.use_mask = use_mask

        if model_type == 'gap':
            self.head = SimpleGAPHead(d_model=d_model, num_classes=num_classes)
        else:  # tpa
            self.head = TPAHead(
                d_model=d_model,
                num_classes=num_classes,
                num_prototypes=tpa_config.get('num_prototypes', 16),
                seg_kernel=tpa_config.get('seg_kernel', 9),
                heads=tpa_config.get('heads', 4),
                dropout=tpa_config.get('dropout', 0.1),
                temperature=tpa_config.get('temperature', 0.07),
                topk_ratio=tpa_config.get('topk_ratio', 0.25),
                use_mask=use_mask
            )

    def forward(self, x, mask: torch.BoolTensor | None = None):
        fmap = self.backbone(x)

        if self.model_type == 'tpa' and mask is not None:
            stride = self.backbone.stride
            mask_float = mask.float().unsqueeze(1)
            mask_down = (F.avg_pool1d(mask_float, kernel_size=stride, stride=stride) == 1.0).squeeze(1)
            return self.head(fmap, mask_down)
        else:
            return self.head(fmap)

# ========================
# 5) Train / Eval
# ========================
def train_one_epoch(model, loader, opt, cfg: Config):
    model.train()
    total, correct, loss_sum = 0, 0, 0.0

    for x, y, mask in loader:
        x = x.to(cfg.device).float()
        y = y.to(cfg.device)
        mask = mask.to(cfg.device).bool()

        opt.zero_grad(set_to_none=True)
        logits = model(x, mask=mask)

        loss = F.cross_entropy(logits, y, label_smoothing=cfg.label_smoothing)

        if torch.isnan(loss):
            print("  Warning: NaN loss detected, skipping batch")
            continue

        loss.backward()
        nn.utils.clip_grad_norm_(model.parameters(), cfg.grad_clip)
        opt.step()

        with torch.no_grad():
            pred = logits.argmax(dim=-1)
            correct += (pred == y).sum().item()
            total += y.size(0)
            loss_sum += loss.item() * y.size(0)

    return {
        "loss": loss_sum / total if total > 0 else 0,
        "acc": correct / total if total > 0 else 0,
    }

@torch.no_grad()
def evaluate(model, loader, cfg: Config, classes=6):
    model.eval()
    ys, ps = [], []

    for x, y, mask in loader:
        x = x.to(cfg.device)
        y = y.to(cfg.device)
        mask = mask.to(cfg.device).bool()

        logits = model(x, mask=mask)
        ps.append(logits.argmax(dim=-1).cpu().numpy())
        ys.append(y.cpu().numpy())

    y_true, y_pred = np.concatenate(ys), np.concatenate(ps)
    acc = accuracy_score(y_true, y_pred)
    f1 = f1_score(y_true, y_pred, average='macro')

    return acc, f1

# ========================
# 6) Robustness Analysis Functions
# ========================

def calculate_robustness_metrics(df_results):
    """Calculate robustness metrics based on performance drop"""
    robustness_metrics = []

    for model in df_results['Model'].unique():
        model_data = df_results[df_results['Model'] == model]

        # Clean accuracy (baseline)
        clean_acc = model_data[model_data['Noise Type'] == 'Clean']['Accuracy'].values[0]

        # Calculate drops for each noise type
        noise_drops = []
        for _, row in model_data.iterrows():
            if row['Noise Type'] != 'Clean':
                drop = clean_acc - row['Accuracy']
                drop_ratio = (drop / clean_acc) * 100  # Drop percentage
                noise_drops.append({
                    'Noise Type': row['Noise Type'],
                    'Accuracy': row['Accuracy'],
                    'Drop (%)': drop,
                    'Drop Ratio (%)': drop_ratio,
                    'Retention (%)': 100 - drop_ratio
                })

        noise_drops_df = pd.DataFrame(noise_drops)

        robustness_metrics.append({
            'Model': model,
            'Clean Acc': clean_acc,
            'Avg Drop (%)': noise_drops_df['Drop (%)'].mean(),
            'Max Drop (%)': noise_drops_df['Drop (%)'].max(),
            'Min Drop (%)': noise_drops_df['Drop (%)'].min(),
            'Std Drop (%)': noise_drops_df['Drop (%)'].std(),
            'Avg Drop Ratio (%)': noise_drops_df['Drop Ratio (%)'].mean(),
            'Avg Retention (%)': noise_drops_df['Retention (%)'].mean(),
            'Worst Retention (%)': noise_drops_df['Retention (%)'].min(),
            'noise_details': noise_drops
        })

    return pd.DataFrame(robustness_metrics)

def print_robustness_comparison(df_results):
    """Print comprehensive robustness comparison based on drop metrics"""

    print(f"\n{'='*100}")
    print("   ROBUSTNESS ANALYSIS: Performance Drop from Clean Baseline")
    print(f"{'='*100}\n")

    # Calculate metrics
    robustness_df = calculate_robustness_metrics(df_results)

    # 1. Overall Robustness Ranking
    print("📊 ROBUSTNESS RANKING (Lower drop = Better robustness)\n")
    print(f"{'Rank':<6} | {'Model':<20} | {'Clean':>8} | {'Avg Drop':>10} | {'Max Drop':>10} | {'Retention':>10}")
    print("-" * 100)

    ranked = robustness_df.sort_values('Avg Drop (%)')
    for idx, row in ranked.iterrows():
        rank = ranked.index.get_loc(idx) + 1
        print(f"{rank:<6} | {row['Model']:<20} | {row['Clean Acc']:7.2f}% | "
              f"{row['Avg Drop (%)']:9.2f}% | {row['Max Drop (%)']:9.2f}% | {row['Avg Retention (%)']:9.2f}%")

    print()

    # 2. Detailed Drop by Noise Type
    print(f"\n{'='*100}")
    print("   DETAILED DROP ANALYSIS BY NOISE TYPE")
    print(f"{'='*100}\n")

    models = df_results['Model'].unique()
    noise_types = [nt for nt in df_results['Noise Type'].unique() if nt != 'Clean']

    for noise_type in noise_types:
        print(f"\n🔍 {noise_type}")
        print(f"{'Model':<20} | {'Accuracy':>10} | {'Drop':>10} | {'Drop Ratio':>12} | {'Retention':>12} | {'Grade':<15}")
        print("-" * 100)

        noise_data = df_results[df_results['Noise Type'] == noise_type].copy()

        for model in models:
            model_row = df_results[df_results['Model'] == model]
            clean_acc = model_row[model_row['Noise Type'] == 'Clean']['Accuracy'].values[0]
            noise_row = noise_data[noise_data['Model'] == model]

            if len(noise_row) > 0:
                acc = noise_row['Accuracy'].values[0]
                drop = clean_acc - acc
                drop_ratio = (drop / clean_acc) * 100
                retention = 100 - drop_ratio

                # Grade based on retention
                if retention >= 95:
                    grade = "🟢 Excellent"
                elif retention >= 90:
                    grade = "🟡 Good"
                elif retention >= 85:
                    grade = "🟠 Fair"
                elif retention >= 80:
                    grade = "🔴 Poor"
                else:
                    grade = "⚫ Very Poor"

                print(f"{model:<20} | {acc:9.2f}% | {drop:9.2f}% | {drop_ratio:11.2f}% | "
                      f"{retention:11.2f}% | {grade}")

    # 3. Robustness Comparison Matrix
    print(f"\n{'='*100}")
    print("   PERFORMANCE DROP (%) MATRIX - Lower is Better")
    print(f"{'='*100}\n")

    drop_matrix = []
    for model in models:
        model_data = df_results[df_results['Model'] == model]
        clean_acc = model_data[model_data['Noise Type'] == 'Clean']['Accuracy'].values[0]

        row_data = {'Model': model}
        for noise_type in noise_types:
            noise_row = model_data[model_data['Noise Type'] == noise_type]
            if len(noise_row) > 0:
                acc = noise_row['Accuracy'].values[0]
                drop = clean_acc - acc
                row_data[noise_type] = drop

        drop_matrix.append(row_data)

    drop_df = pd.DataFrame(drop_matrix).set_index('Model')
    print(drop_df.to_string(float_format="%.2f"))
    print()

    # 4. Comparative Robustness Analysis
    print(f"\n{'='*100}")
    print("   COMPARATIVE ROBUSTNESS IMPROVEMENT")
    print(f"{'='*100}\n")

    baseline_model = 'GAP_Baseline'
    baseline_metrics = robustness_df[robustness_df['Model'] == baseline_model].iloc[0]

    print(f"Baseline Model: {baseline_model}")
    print(f"  - Clean Accuracy: {baseline_metrics['Clean Acc']:.2f}%")
    print(f"  - Average Drop: {baseline_metrics['Avg Drop (%)']:.2f}%")
    print(f"  - Average Retention: {baseline_metrics['Avg Retention (%)']:.2f}%\n")

    print("Improvements over Baseline:\n")
    print(f"{'Model':<20} | {'Drop Reduction':>15} | {'Retention Gain':>15} | {'Robustness Score':>18}")
    print("-" * 100)

    for _, row in robustness_df.iterrows():
        if row['Model'] != baseline_model:
            drop_reduction = baseline_metrics['Avg Drop (%)'] - row['Avg Drop (%)']
            retention_gain = row['Avg Retention (%)'] - baseline_metrics['Avg Retention (%)']

            # Robustness Score: combination of clean acc and retention
            robustness_score = (row['Clean Acc'] * 0.3 + row['Avg Retention (%)'] * 0.7)

            print(f"{row['Model']:<20} | {drop_reduction:+14.2f}% | {retention_gain:+14.2f}% | {robustness_score:17.2f}")

    print()

    # 5. Noise Type Vulnerability Analysis
    print(f"\n{'='*100}")
    print("   NOISE TYPE VULNERABILITY ANALYSIS")
    print(f"{'='*100}\n")

    print("Most Challenging Noise Types (Highest Average Drop):\n")

    noise_difficulty = []
    for noise_type in noise_types:
        noise_data = df_results[df_results['Noise Type'] == noise_type]
        avg_drops = []

        for model in models:
            model_row = df_results[df_results['Model'] == model]
            clean_acc = model_row[model_row['Noise Type'] == 'Clean']['Accuracy'].values[0]
            noise_row = noise_data[noise_data['Model'] == model]

            if len(noise_row) > 0:
                acc = noise_row['Accuracy'].values[0]
                drop = clean_acc - acc
                avg_drops.append(drop)

        noise_difficulty.append({
            'Noise Type': noise_type,
            'Avg Drop Across Models': np.mean(avg_drops),
            'Max Drop': np.max(avg_drops),
            'Min Drop': np.min(avg_drops)
        })

    difficulty_df = pd.DataFrame(noise_difficulty).sort_values('Avg Drop Across Models', ascending=False)

    print(f"{'Rank':<6} | {'Noise Type':<25} | {'Avg Drop':>10} | {'Max Drop':>10} | {'Min Drop':>10}")
    print("-" * 100)

    for idx, row in difficulty_df.iterrows():
        rank = difficulty_df.index.get_loc(idx) + 1
        print(f"{rank:<6} | {row['Noise Type']:<25} | {row['Avg Drop Across Models']:9.2f}% | "
              f"{row['Max Drop']:9.2f}% | {row['Min Drop']:9.2f}%")

    print()

    # 6. Model-Specific Strengths
    print(f"\n{'='*100}")
    print("   MODEL-SPECIFIC STRENGTHS & WEAKNESSES")
    print(f"{'='*100}\n")

    for model in models:
        print(f"\n📌 {model}")
        model_details = robustness_df[robustness_df['Model'] == model].iloc[0]
        noise_details = pd.DataFrame(model_details['noise_details'])

        # Best performance (lowest drop)
        best_noise = noise_details.loc[noise_details['Drop (%)'].idxmin()]
        worst_noise = noise_details.loc[noise_details['Drop (%)'].idxmax()]

        print(f"  Overall:")
        print(f"    - Clean Accuracy: {model_details['Clean Acc']:.2f}%")
        print(f"    - Average Drop: {model_details['Avg Drop (%)']:.2f}%")
        print(f"    - Average Retention: {model_details['Avg Retention (%)']:.2f}%")
        print(f"  Strongest Against: {best_noise['Noise Type']} (Drop: {best_noise['Drop (%)']:.2f}%)")
        print(f"  Weakest Against: {worst_noise['Noise Type']} (Drop: {worst_noise['Drop (%)']:.2f}%)")

    print()

    return robustness_df

def print_drop_heatmap(df_results):
    """Print drop percentage heatmap with color coding"""
    print(f"\n{'='*100}")
    print("   DROP PERCENTAGE HEATMAP")
    print(f"{'='*100}\n")

    models = df_results['Model'].unique()
    noise_types = [nt for nt in df_results['Noise Type'].unique() if nt != 'Clean']

    # Print header
    print(f"{'Noise Type':<25} | ", end="")
    for model in models:
        print(f"{model:>20} | ", end="")
    print()
    print("-" * 100)

    # Print drops with visual indicators
    for noise_type in noise_types:
        print(f"{noise_type:<25} | ", end="")

        for model in models:
            model_row = df_results[df_results['Model'] == model]
            clean_acc = model_row[model_row['Noise Type'] == 'Clean']['Accuracy'].values[0]
            noise_row = model_row[model_row['Noise Type'] == noise_type]

            if len(noise_row) > 0:
                acc = noise_row['Accuracy'].values[0]
                drop = clean_acc - acc
                drop_ratio = (drop / clean_acc) * 100

                # Visual indicator
                if drop < 1:
                    indicator = "🟢"
                elif drop < 3:
                    indicator = "🟡"
                elif drop < 5:
                    indicator = "🟠"
                else:
                    indicator = "🔴"

                print(f"{drop:6.2f}% ({drop_ratio:5.1f}%) {indicator} | ", end="")

        print()

    print("\nLegend: 🟢 <1% | 🟡 1-3% | 🟠 3-5% | 🔴 >5%")
    print()

def save_robustness_report(df_results, robustness_df, save_dir):
    """Save comprehensive robustness report"""

    # 1. Save detailed CSV
    detailed_results = []

    for model in df_results['Model'].unique():
        model_data = df_results[df_results['Model'] == model]
        clean_acc = model_data[model_data['Noise Type'] == 'Clean']['Accuracy'].values[0]

        for _, row in model_data.iterrows():
            if row['Noise Type'] != 'Clean':
                drop = clean_acc - row['Accuracy']
                drop_ratio = (drop / clean_acc) * 100
                retention = 100 - drop_ratio

                detailed_results.append({
                    'Model': model,
                    'Noise Type': row['Noise Type'],
                    'Clean Accuracy': clean_acc,
                    'Noisy Accuracy': row['Accuracy'],
                    'Drop (%)': drop,
                    'Drop Ratio (%)': drop_ratio,
                    'Retention (%)': retention,
                    'F1': row['F1']
                })

    detailed_df = pd.DataFrame(detailed_results)
    detailed_df.to_csv(os.path.join(save_dir, 'robustness_detailed.csv'), index=False)

    # 2. Save summary JSON
    summary = {}
    for _, row in robustness_df.iterrows():
        summary[row['Model']] = {
            'clean_accuracy': float(row['Clean Acc']),
            'average_drop': float(row['Avg Drop (%)']),
            'max_drop': float(row['Max Drop (%)']),
            'min_drop': float(row['Min Drop (%)']),
            'std_drop': float(row['Std Drop (%)']),
            'average_retention': float(row['Avg Retention (%)']),
            'worst_retention': float(row['Worst Retention (%)'])
        }

    with open(os.path.join(save_dir, 'robustness_summary.json'), 'w') as f:
        json.dump(summary, f, indent=2)

    print(f"✓ Robustness reports saved to:")
    print(f"  - {os.path.join(save_dir, 'robustness_detailed.csv')}")
    print(f"  - {os.path.join(save_dir, 'robustness_summary.json')}")

# ========================
# 7) Noise Robustness Study
# ========================
def run_noise_robustness_study(cfg: Config):
    os.makedirs(cfg.save_dir, exist_ok=True)

    # Load data
    print(f"\n{'='*70}")
    print("   DATA PREPARATION")
    print(f"{'='*70}")

    X_full, y_full = _load_split_raw(cfg.data_dir, "train")
    mean = X_full.mean(axis=(0,2), keepdims=True)
    std = X_full.std(axis=(0,2), keepdims=True) + 1e-6
    X_full = ((X_full - mean) / std).astype(np.float32)

    indices = np.arange(len(X_full))
    train_indices, val_indices = train_test_split(
        indices,
        test_size=cfg.val_split,
        random_state=SEED,
        stratify=y_full
    )

    print(f"Total samples: {len(X_full)}")
    print(f"  → Train: {len(train_indices)} ({(1-cfg.val_split)*100:.0f}%)")
    print(f"  → Val:   {len(val_indices)} ({cfg.val_split*100:.0f}%)")

    # Create datasets
    train_set = UCIHARInertial(
        cfg.data_dir, "train",
        mean=mean, std=std,
        preloaded_data=(X_full, y_full),
        indices=train_indices,
        augment=True,
        max_noise=cfg.max_train_noise,
        augment_prob=cfg.train_augment_prob
    )

    val_set = UCIHARInertial(
        cfg.data_dir, "train",
        mean=mean, std=std,
        preloaded_data=(X_full, y_full),
        indices=val_indices,
        augment=False
    )

    test_set_orig = UCIHARInertial(cfg.data_dir, "test", mean=mean, std=std, augment=False)

    val_loader = DataLoader(val_set, cfg.batch_size, num_workers=cfg.num_workers)

    # Define noise configurations for testing
    noise_configs = [
        {'name': 'Clean', 'type': 'none', 'level': 0.0},
        {'name': '20% Gaussian', 'type': 'gaussian', 'level': 0.2},
        {'name': '40% Gaussian', 'type': 'gaussian', 'level': 0.4},
        {'name': '30% Temporal Mask', 'type': 'temporal', 'level': 0.3},
        {'name': '20% Channel Drift', 'type': 'drift', 'level': 0.2},
        {'name': '20% Channel Drop', 'type': 'drop', 'level': 0.2},
        {'name': '5% Spike Noise', 'type': 'spike', 'level': 0.05},
    ]

    # Define 3 model configurations
    model_configs = [
        {"name": "GAP_Baseline", "model_type": "gap", "use_mask": False},
        {"name": "TPA", "model_type": "tpa", "use_mask": False},
        {"name": "TPA_WithMask", "model_type": "tpa", "use_mask": True},
    ]

    print(f"\n{'='*70}")
    print("    NOISE ROBUSTNESS STUDY: 3-Model Comparison")
    print(f"{'='*70}")
    print(f"   Models: GAP (Baseline), TPA, TPA+Mask")
    print(f"   Training: Mixed noise augmentation (prob={cfg.train_augment_prob}, max={cfg.max_train_noise})")
    print(f"   Testing: 7 noise scenarios")
    print(f"{'='*70}\n")

    results_table = []

    # Train and evaluate each model
    for model_cfg in model_configs:
        print(f"\n{'='*70}")
        print(f"   MODEL: {model_cfg['name']}")
        print(f"{'='*70}")

        random.seed(SEED); np.random.seed(SEED); torch.manual_seed(SEED); torch.cuda.manual_seed_all(SEED)

        tpa_config = {
            'num_prototypes': cfg.tpa_num_prototypes,
            'seg_kernel': cfg.tpa_seg_kernel,
            'heads': cfg.tpa_heads,
            'dropout': cfg.tpa_dropout,
            'temperature': cfg.tpa_temperature,
            'topk_ratio': cfg.tpa_topk_ratio
        }

        model = HAR_Model(
            d_model=cfg.d_model,
            model_type=model_cfg["model_type"],
            use_mask=model_cfg["use_mask"],
            tpa_config=tpa_config
        ).to(cfg.device).float()

        g = torch.Generator(device='cpu').manual_seed(SEED)
        train_loader = DataLoader(train_set, cfg.batch_size, shuffle=True,
                                 num_workers=cfg.num_workers, generator=g)

        opt = torch.optim.AdamW(model.parameters(), lr=cfg.lr, weight_decay=cfg.weight_decay)

        best_acc, best_wts = 0.0, None
        patience_counter = 0
        best_epoch = 0

        print(f"\nTraining {model_cfg['name']} for up to {cfg.epochs} epochs (patience={cfg.patience})...")

        for epoch in range(1, cfg.epochs + 1):
            stats = train_one_epoch(model, train_loader, opt, cfg)
            val_acc, val_f1 = evaluate(model, val_loader, cfg)

            improved = False
            if val_acc > best_acc + cfg.min_delta:
                best_acc = val_acc
                best_wts = copy.deepcopy(model.state_dict())
                patience_counter = 0
                best_epoch = epoch
                improved = True
            else:
                patience_counter += 1

            log_str = f"[{epoch:02d}/{cfg.epochs}] Train L:{stats['loss']:.4f} A:{stats['acc']:.4f}"
            log_str += f" | Val A:{val_acc:.4f} F1:{val_f1:.4f}"
            if improved:
                log_str += " ✓"

            if epoch % 10 == 0 or epoch == 1:
                print(log_str)

            if patience_counter >= cfg.patience:
                print(f"\n⚠ Early stopping triggered at epoch {epoch}")
                print(f"  Best validation acc: {best_acc:.4f} (epoch {best_epoch})")
                break

        if best_wts:
            model_path = os.path.join(cfg.save_dir, f"model_{model_cfg['name']}.pth")
            torch.save(best_wts, model_path)
            model.load_state_dict(best_wts)
            print(f"\n✓ Best Val Acc: {best_acc:.4f} (epoch {best_epoch})")

        # Test on all noise scenarios
        print(f"\n   Testing on {len(noise_configs)} noise scenarios...")
        print(f"\n{'Noise Type':<25} | {'Accuracy':>8} | {'F1':>6}")
        print("-" * 50)

        for noise_cfg in noise_configs:
            X_test = test_set_orig.X.copy()

            # Apply noise and generate mask
            if noise_cfg['type'] != 'none':
                if noise_cfg['type'] == 'gaussian':
                    X_test, test_masks = add_gaussian_noise(X_test, noise_cfg['level'])
                elif noise_cfg['type'] == 'temporal':
                    X_test, test_masks = add_temporal_mask(X_test, noise_cfg['level'])
                elif noise_cfg['type'] == 'drift':
                    X_test, test_masks = add_channel_drift(X_test, noise_cfg['level'])
                elif noise_cfg['type'] == 'drop':
                    X_test, test_masks = add_channel_drop(X_test, noise_cfg['level'])
                elif noise_cfg['type'] == 'spike':
                    X_test, test_masks = add_spike_noise(X_test, noise_cfg['level'])
            else:
                # Clean - all masks True
                test_masks = np.ones((X_test.shape[0], X_test.shape[2]), dtype=bool)

            # Create test dataset with masks
            test_ds_noisy = UCIHARInertialWithMask(
                X_test, test_set_orig.y, test_masks
            )
            test_loader_noisy = DataLoader(test_ds_noisy, cfg.batch_size,
                                          num_workers=cfg.num_workers)

            acc, f1 = evaluate(model, test_loader_noisy, cfg)

            print(f"{noise_cfg['name']:<25} | {acc*100:7.2f}% | {f1:5.3f}")

            results_table.append({
                'Model': model_cfg['name'],
                'Noise Type': noise_cfg['name'],
                'Accuracy': acc * 100,
                'F1': f1
            })

        print()

    # ========================
    # Analysis with Drop Metrics
    # ========================
    print(f"\n{'='*100}")
    print("   COMPREHENSIVE ROBUSTNESS ANALYSIS")
    print(f"{'='*100}\n")

    df_results = pd.DataFrame(results_table)

    # 1. Print robustness comparison
    robustness_df = print_robustness_comparison(df_results)

    # 2. Print drop heatmap
    print_drop_heatmap(df_results)

    # 3. Original pivot table (for reference)
    pivot_acc = df_results.pivot(index='Noise Type', columns='Model', values='Accuracy')
    print(f"\n{'='*100}")
    print("   ACCURACY (%) REFERENCE TABLE")
    print(f"{'='*100}\n")
    print(pivot_acc.to_string(float_format="%.2f"))
    print()

    # 4. Save reports
    save_robustness_report(df_results, robustness_df, cfg.save_dir)

    # 5. Key findings summary
    print(f"\n{'='*100}")
    print("   KEY FINDINGS SUMMARY")
    print(f"{'='*100}\n")

    best_robust = robustness_df.loc[robustness_df['Avg Drop (%)'].idxmin()]
    worst_robust = robustness_df.loc[robustness_df['Avg Drop (%)'].idxmax()]

    print(f"🏆 Most Robust Model: {best_robust['Model']}")
    print(f"   - Average Drop: {best_robust['Avg Drop (%)']:.2f}%")
    print(f"   - Average Retention: {best_robust['Avg Retention (%)']:.2f}%")
    print(f"   - Clean Accuracy: {best_robust['Clean Acc']:.2f}%")
    print()

    print(f"⚠️  Least Robust Model: {worst_robust['Model']}")
    print(f"   - Average Drop: {worst_robust['Avg Drop (%)']:.2f}%")
    print(f"   - Average Retention: {worst_robust['Avg Retention (%)']:.2f}%")
    print(f"   - Clean Accuracy: {worst_robust['Clean Acc']:.2f}%")
    print()

    # Comparative improvement
    if 'GAP_Baseline' in robustness_df['Model'].values:
        baseline = robustness_df[robustness_df['Model'] == 'GAP_Baseline'].iloc[0]

        print("📈 Improvements over Baseline:")
        for _, row in robustness_df.iterrows():
            if row['Model'] != 'GAP_Baseline':
                drop_improvement = baseline['Avg Drop (%)'] - row['Avg Drop (%)']
                print(f"   {row['Model']}: {drop_improvement:+.2f}% drop reduction")

    print(f"\n{'='*100}\n")

    # Save main results
    df_results.to_csv(os.path.join(cfg.save_dir, 'noise_robustness_results.csv'), index=False)

    return df_results, robustness_df

# ========================
# 8) Visualization
# ========================
def plot_noise_robustness_comparison(df_results, robustness_df, save_path):
    """Create comprehensive visualization of noise robustness results"""
    import matplotlib.pyplot as plt
    import seaborn as sns

    sns.set_style("whitegrid")
    fig, axes = plt.subplots(2, 2, figsize=(16, 12))
    fig.suptitle('Noise Robustness Study: 3-Model Comparison', fontsize=16, fontweight='bold')

    # Plot 1: Drop by noise type
    models = df_results['Model'].unique()
    noise_types = [nt for nt in df_results['Noise Type'].unique() if nt != 'Clean']

    drop_data = []
    for model in models:
        model_data = df_results[df_results['Model'] == model]
        clean_acc = model_data[model_data['Noise Type'] == 'Clean']['Accuracy'].values[0]

        for noise_type in noise_types:
            noise_row = model_data[model_data['Noise Type'] == noise_type]
            if len(noise_row) > 0:
                acc = noise_row['Accuracy'].values[0]
                drop = clean_acc - acc
                drop_data.append({'Model': model, 'Noise Type': noise_type, 'Drop': drop})

    drop_df = pd.DataFrame(drop_data)
    pivot_drop = drop_df.pivot(index='Noise Type', columns='Model', values='Drop')
    pivot_drop.plot(kind='bar', ax=axes[0, 0], width=0.8, edgecolor='black')
    axes[0, 0].set_title('Performance Drop by Noise Type (Lower is Better)', fontsize=14)
    axes[0, 0].set_ylabel('Drop (%)', fontsize=12)
    axes[0, 0].set_xlabel('Noise Type', fontsize=12)
    axes[0, 0].legend(title='Model', loc='upper left')
    axes[0, 0].grid(axis='y', alpha=0.3)
    axes[0, 0].tick_params(axis='x', rotation=45)

    # Plot 2: Average drop comparison
    avg_drops = robustness_df.set_index('Model')['Avg Drop (%)']
    colors = ['#e74c3c', '#3498db', '#2ecc71']
    avg_drops.plot(kind='bar', ax=axes[0, 1], color=colors, edgecolor='black', width=0.6)
    axes[0, 1].set_title('Average Drop Across All Noise Types', fontsize=14)
    axes[0, 1].set_ylabel('Average Drop (%)', fontsize=12)
    axes[0, 1].set_xlabel('Model', fontsize=12)
    axes[0, 1].tick_params(axis='x', rotation=45)
    axes[0, 1].grid(axis='y', alpha=0.3)
    for i, v in enumerate(avg_drops):
        axes[0, 1].text(i, v + 0.2, f"{v:.2f}%", ha='center', fontweight='bold')

    # Plot 3: Retention rates
    retention_data = []
    for model in models:
        model_data = df_results[df_results['Model'] == model]
        clean_acc = model_data[model_data['Noise Type'] == 'Clean']['Accuracy'].values[0]

        for noise_type in noise_types:
            noise_row = model_data[model_data['Noise Type'] == noise_type]
            if len(noise_row) > 0:
                acc = noise_row['Accuracy'].values[0]
                retention = (acc / clean_acc) * 100
                retention_data.append({'Model': model, 'Noise Type': noise_type, 'Retention': retention})

    retention_df = pd.DataFrame(retention_data)
    pivot_retention = retention_df.pivot(index='Noise Type', columns='Model', values='Retention')
    pivot_retention.plot(kind='line', ax=axes[1, 0], marker='o', linewidth=2, markersize=8)
    axes[1, 0].set_title('Retention Rate by Noise Type (Higher is Better)', fontsize=14)
    axes[1, 0].set_ylabel('Retention (%)', fontsize=12)
    axes[1, 0].set_xlabel('Noise Type', fontsize=12)
    axes[1, 0].legend(title='Model', loc='lower left')
    axes[1, 0].grid(True, alpha=0.3)
    axes[1, 0].tick_params(axis='x', rotation=45)
    axes[1, 0].axhline(y=90, color='red', linestyle='--', alpha=0.5, label='90% threshold')

    # Plot 4: Clean vs Average Noisy Accuracy
    clean_noisy_data = []
    for model in models:
        model_data = df_results[df_results['Model'] == model]
        clean_acc = model_data[model_data['Noise Type'] == 'Clean']['Accuracy'].values[0]
        noisy_accs = model_data[model_data['Noise Type'] != 'Clean']['Accuracy']
        avg_noisy = noisy_accs.mean()
        clean_noisy_data.append({'Model': model, 'Clean': clean_acc, 'Avg Noisy': avg_noisy})

    cn_df = pd.DataFrame(clean_noisy_data).set_index('Model')
    cn_df.plot(kind='bar', ax=axes[1, 1], width=0.7, edgecolor='black')
    axes[1, 1].set_title('Clean vs Average Noisy Accuracy', fontsize=14)
    axes[1, 1].set_ylabel('Accuracy (%)', fontsize=12)
    axes[1, 1].set_xlabel('Model', fontsize=12)
    axes[1, 1].legend(title='Condition')
    axes[1, 1].grid(axis='y', alpha=0.3)
    axes[1, 1].tick_params(axis='x', rotation=45)

    plt.tight_layout()
    plt.savefig(save_path, dpi=300, bbox_inches='tight')
    plt.close()
    print(f"✓ Visualization saved to: {save_path}")

# ========================
# 9) Main Execution
# ========================
if __name__ == "__main__":
    config = Config()
    config.epochs = 100
    config.lr = 1e-4

    # Training augmentation settings
    config.train_augment_prob = 0.7
    config.max_train_noise = 0.4

    # Early stopping
    config.patience = 20
    config.min_delta = 0.0001
    config.val_split = 0.2

    # TPA hyperparameters
    config.tpa_num_prototypes = 16
    config.tpa_seg_kernel = 9
    config.tpa_heads = 4
    config.tpa_dropout = 0.1
    config.tpa_temperature = 0.07
    config.tpa_topk_ratio = 0.25

    print("\n" + "="*70)
    print("    NOISE ROBUSTNESS STUDY")
    print("="*70)
    print(f"Device: {config.device}")
    print(f"Epochs: {config.epochs}")
    print(f"Learning Rate: {config.lr}")
    print(f"Training Augmentation: prob={config.train_augment_prob}, max_noise={config.max_train_noise}")
    print(f"\n3 Models to Compare:")
    print(f"  1) GAP_Baseline:  Standard GAP (baseline)")
    print(f"  2) TPA:           Temporal Prototype Attention")
    print(f"  3) TPA_WithMask:  TPA + mask filtering")
    print(f"\n7 Noise Scenarios:")
    print(f"  - Clean (no noise)")
    print(f"  - Gaussian noise (20%, 40%)")
    print(f"  - Temporal mask (30%)")
    print(f"  - Channel drift (20%)")
    print(f"  - Channel drop (20%)")
    print(f"  - Spike noise (5%)")
    print(f"\nTPA Configuration:")
    print(f"  Prototypes: {config.tpa_num_prototypes}")
    print(f"  Heads: {config.tpa_heads}")
    print(f"  Temperature: {config.tpa_temperature}")
    print(f"  TopK Ratio: {config.tpa_topk_ratio}")
    print("="*70 + "\n")

    # Run study
    df_results, robustness_df = run_noise_robustness_study(config)

    # Generate visualization
    plot_path = os.path.join(config.save_dir, "noise_robustness_comparison.png")
    plot_noise_robustness_comparison(df_results, robustness_df, plot_path)

    print("\n" + "="*70)
    print("STUDY COMPLETED!")
    print("="*70)
    print("\nGenerated files:")
    print(f"  1. {os.path.join(config.save_dir, 'noise_robustness_results.csv')}")
    print(f"  2. {os.path.join(config.save_dir, 'robustness_detailed.csv')}")
    print(f"  3. {os.path.join(config.save_dir, 'robustness_summary.json')}")
    print(f"  4. {os.path.join(config.save_dir, 'noise_robustness_comparison.png')}")
    print(f"  5. model_*.pth files for all 3 models")
    print("="*70 + "\n")


    NOISE ROBUSTNESS STUDY
Device: cuda
Epochs: 100
Learning Rate: 0.0001
Training Augmentation: prob=0.7, max_noise=0.4

3 Models to Compare:
  1) GAP_Baseline:  Standard GAP (baseline)
  2) TPA:           Temporal Prototype Attention
  3) TPA_WithMask:  TPA + mask filtering

7 Noise Scenarios:
  - Clean (no noise)
  - Gaussian noise (20%, 40%)
  - Temporal mask (30%)
  - Channel drift (20%)
  - Channel drop (20%)
  - Spike noise (5%)

TPA Configuration:
  Prototypes: 16
  Heads: 4
  Temperature: 0.07
  TopK Ratio: 0.25


   DATA PREPARATION
Total samples: 7352
  → Train: 5881 (80%)
  → Val:   1471 (20%)

    NOISE ROBUSTNESS STUDY: 3-Model Comparison
   Models: GAP (Baseline), TPA, TPA+Mask
   Training: Mixed noise augmentation (prob=0.7, max=0.4)
   Testing: 7 noise scenarios


   MODEL: GAP_Baseline

Training GAP_Baseline for up to 100 epochs (patience=20)...
[01/100] Train L:1.4717 A:0.5215 | Val A:0.8192 F1:0.8030 ✓
[10/100] Train L:0.4458 A:0.9299 | Val A:0.9409 F1:0.9444
[20/1