In [None]:
# Colab-only setup: mount Google Drive so that paths under /content/drive
# (used by Config.data_dir / Config.save_dir below) are reachable.
from google.colab import drive
drive.mount('/content/drive')

Mounted at /content/drive


In [None]:
# -*- coding: utf-8 -*-
"""
UCI-HAR Ablation Study: GAP vs Pure TPA (No Conv layers)
"""
import os, random, math, sys, time, copy, json
import numpy as np
from typing import Tuple, List
from dataclasses import dataclass

import torch
import torch.nn as nn
import torch.nn.functional as F
from torch.utils.data import Dataset, DataLoader

from sklearn.metrics import accuracy_score, f1_score, confusion_matrix, classification_report

# ========================
# 0) Config & Reproducibility
# ========================
SEED = 42

# Seed every random number generator the script draws from: Python's
# `random`, NumPy's global RNG, and PyTorch (CPU and all CUDA devices).
random.seed(SEED)
np.random.seed(SEED)
torch.manual_seed(SEED)
torch.cuda.manual_seed_all(SEED)

# Ask cuDNN for deterministic kernels and disable its autotuner so repeated
# runs pick the same algorithms.
torch.backends.cudnn.benchmark = False
torch.backends.cudnn.deterministic = True

@dataclass
class Config:
    """Paths and hyperparameters for the GAP vs Pure TPA ablation run.

    The defaults here are a baseline; the ``__main__`` block overrides several
    of them (for example ``epochs``) before calling ``run_ablation_study``.
    """
    mode: str = "ablation"  # the __main__ block only handles "ablation"
    data_dir: str = "/content/drive/MyDrive/AI_data/UCI_HAR_Dataset/UCI HAR Dataset"  # dataset root containing train/ and test/
    save_dir: str = "/content/drive/MyDrive/AI_data/ablation_pure_tpa"  # checkpoints and the results JSON are written here

    epochs: int = 25  # maximum number of training epochs per configuration
    batch_size: int = 128
    lr: float = 1e-4  # AdamW learning rate
    weight_decay: float = 1e-4  # AdamW weight decay
    grad_clip: float = 1.0  # max gradient norm passed to clip_grad_norm_
    label_smoothing: float = 0.05  # passed to F.cross_entropy

    # Early stopping
    patience: int = 20  # epochs without validation-accuracy improvement before stopping
    min_delta: float = 0.0001  # minimum validation-accuracy gain that counts as an improvement
    val_split: float = 0.2  # fraction of the training split held out for validation

    train_augment_prob: float = 0.25  # per-batch probability of applying transition augmentation
    train_augment_mix: float = 0.35  # fraction of the window overwritten by the augmentation

    d_model: int = 128  # feature width produced by the backbone and consumed by the head
    use_tpa: bool = False  # overwritten per configuration inside run_ablation_study

    # TPA hyperparameters
    tpa_num_prototypes: int = 16
    tpa_heads: int = 4
    tpa_dropout: float = 0.1
    tpa_temperature: float = 0.07  # divisor applied to the cosine-similarity attention scores
    tpa_topk_ratio: float = 0.25  # fraction of prototypes kept by the top-k step

    # Evaluated once, when the class body is executed.
    device: str = "cuda" if torch.cuda.is_available() else "cpu"
    num_workers: int = 2  # DataLoader worker processes

# ========================
# 1) UCI-HAR Data Loader
# ========================
# (path prefix, file extension) for each of the nine inertial channels; the
# split name ("train"/"test") is inserted between prefix and extension by the
# loader. Order: total_acc x/y/z, body_acc x/y/z, body_gyro x/y/z.
_RAW_CHANNELS = [
    (f"Inertial Signals/{signal}_{axis}_", "txt")
    for signal in ("total_acc", "body_acc", "body_gyro")
    for axis in ("x", "y", "z")
]

# Activity names keyed by the 1-based ids stored in the label files.
_LABEL_MAP = {
    1: "WALKING",
    2: "WALKING_UPSTAIRS",
    3: "WALKING_DOWNSTAIRS",
    4: "SITTING",
    5: "STANDING",
    6: "LAYING",
}

# 0-based class code <-> activity name, as used by the model and the metrics.
_CODE_TO_LABEL_NAME = {raw_id - 1: name for raw_id, name in _LABEL_MAP.items()}
_LABEL_NAME_TO_CODE = {name: code for code, name in _CODE_TO_LABEL_NAME.items()}

def _load_split_raw(root: str, split: str) -> Tuple[np.ndarray, np.ndarray]:
    """Load the raw inertial signals and labels of one dataset split.

    Args:
        root: Dataset root directory that contains the ``train``/``test`` folders.
        split: Either ``"train"`` or ``"test"``.

    Returns:
        ``(X, y)`` where ``X`` has shape ``(samples, channels, time)`` with the
        channels ordered as in ``_RAW_CHANNELS``, and ``y`` holds the integer
        labels exactly as stored in ``y_<split>.txt``.

    Raises:
        ValueError: If ``split`` is not ``"train"`` or ``"test"``.
    """
    # Explicit check instead of `assert`, which disappears under `python -O`.
    if split not in ("train", "test"):
        raise ValueError(f"split must be 'train' or 'test', got {split!r}")

    split_dir = os.path.join(root, split)
    # ndmin=2 keeps a (samples, time) layout even when a file has a single row.
    X_list = [
        np.loadtxt(os.path.join(split_dir, prefix + split + "." + ext), ndmin=2)[..., None]
        for prefix, ext in _RAW_CHANNELS
    ]
    # (samples, time, channels) -> (samples, channels, time)
    X = np.concatenate(X_list, axis=-1).transpose(0, 2, 1)
    # ndmin=1 keeps the labels indexable even when the file holds one value.
    y = np.loadtxt(os.path.join(split_dir, f"y_{split}.txt"), ndmin=1).astype(int)
    return X, y

class UCIHARInertial(Dataset):
    def __init__(self, root: str, split: str, mean=None, std=None,
                 preloaded_data: Tuple[np.ndarray, np.ndarray] | None = None,
                 indices: np.ndarray | None = None):
        super().__init__()

        if preloaded_data is not None:
            X, y = preloaded_data
        else:
            X, y = _load_split_raw(root, split)

        if indices is not None:
            X = X[indices]
            y = y[indices]

        self.X = X.astype(np.float32)
        self.y = (y - 1).astype(np.int64) if y.min() >= 1 else y.astype(np.int64)

        if mean is not None and std is not None:
            self.mean, self.std = mean, std
            if preloaded_data is None:
                self.X = (self.X - self.mean) / self.std
        else:
            self.mean = self.X.mean(axis=(0,2), keepdims=True).astype(np.float32)
            self.std = (self.X.std(axis=(0,2), keepdims=True) + 1e-6).astype(np.float32)
            self.X = ((self.X - self.mean) / self.std).astype(np.float32)

    def __len__(self):
        return self.X.shape[0]

    def __getitem__(self, idx):
        return (
            torch.from_numpy(self.X[idx]).float(),
            torch.tensor(self.y[idx], dtype=torch.long)
        )

# ========================
# 2) Online Transition Augmentation
# ========================
def apply_transition_augmentation(x: torch.Tensor, y: torch.Tensor, mix_ratio: float = 0.25) -> torch.Tensor:
    """Simulate activity transitions inside a batch.

    Each sample is, with probability 0.5, given a new tail: its last
    ``int(T * mix_ratio)`` time steps are overwritten with the *first* time
    steps of a randomly chosen sample from the same batch that carries a
    different label. Labels are left untouched.

    Args:
        x: Batch of signals with shape ``(B, C, T)``. Modified in place.
        y: Labels with shape ``(B,)``.
        mix_ratio: Fraction of the window to overwrite.

    Returns:
        The same tensor object ``x`` after the in-place edits.
    """
    B, C, T = x.shape

    mix_pts = min(int(T * mix_ratio), T)
    # With zero points the slice `-0:` would select the WHOLE window and the
    # assignment below would fail on mismatched shapes, so bail out early.
    if mix_pts <= 0:
        return x

    for i in range(B):
        if random.random() < 0.5:
            other_class_indices = (y != y[i]).nonzero(as_tuple=True)[0]
            if len(other_class_indices) > 0:
                j = other_class_indices[random.randint(0, len(other_class_indices) - 1)]
                # clone() guards against aliasing when source and target overlap.
                x[i, :, -mix_pts:] = x[j, :, :mix_pts].clone()

    return x

# ========================
# 3) Pure TPA Module (No Conv layers)
# ========================
class PureTPA(nn.Module):
    """Temporal Prototype Attention pooling without convolutional preprocessing.

    Learnable prototype vectors act as attention queries over the time steps of
    the input. The attended prototype tokens are reduced to one vector, refined,
    and blended with plain average pooling through a learned confidence gate.
    """

    def __init__(self, dim, num_prototypes=16, heads=4, dropout=0.1,
                 temperature=0.07, topk_ratio=0.25):
        super().__init__()
        assert dim % heads == 0

        self.dim = dim
        self.heads = heads
        self.head_dim = dim // heads
        self.num_prototypes = num_prototypes
        self.temperature = temperature
        self.topk_ratio = topk_ratio

        # Prototype queries, initialised with small Gaussian noise.
        self.proto = nn.Parameter(torch.randn(num_prototypes, dim) * 0.02)

        # LayerNorm stands in for any convolutional preprocessing.
        self.pre_norm = nn.LayerNorm(dim)

        def square_projection():
            return nn.Linear(dim, dim, bias=False)

        # Attention projections (created in q, k, v, out order).
        self.q_proj = square_projection()
        self.k_proj = square_projection()
        self.v_proj = square_projection()
        self.out_proj = square_projection()

        # MLP applied to the pooled prototype summary.
        self.fuse = nn.Sequential(
            nn.Linear(dim, dim),
            nn.SiLU(),
            nn.Dropout(dropout),
            nn.Linear(dim, dim),
        )

        # Scalar gate deciding how much to trust the TPA vector.
        gate_hidden = dim // 4
        self.conf_head = nn.Sequential(
            nn.Linear(dim, gate_hidden),
            nn.ReLU(),
            nn.Dropout(dropout),
            nn.Linear(gate_hidden, 1),
        )

        self.dropout = nn.Dropout(dropout)

    def forward(self, x, return_confidence=False):
        """Pool ``x`` of shape ``[B, T, D]`` into ``[B, D]``.

        Returns the pooled tensor, or ``(pooled, confidence)`` with confidence
        of shape ``[B, 1]`` when ``return_confidence`` is true.
        """
        batch, steps, width = x.shape
        n_proto = self.num_prototypes

        def to_heads(t, length):
            # [B, length, D] -> [B, H, length, D/H]
            return t.view(batch, length, self.heads, self.head_dim).transpose(1, 2)

        tokens = self.pre_norm(x)

        # Keys and values come from the time steps, queries from the prototypes.
        keys = to_heads(self.k_proj(tokens), steps)
        values = to_heads(self.v_proj(tokens), steps)
        queries = to_heads(
            self.q_proj(self.proto).unsqueeze(0).expand(batch, -1, -1), n_proto
        )

        # Unit-normalised queries/keys turn the dot product into cosine
        # similarity, which is then sharpened by the temperature.
        queries = F.normalize(queries, dim=-1)
        keys = F.normalize(keys, dim=-1)
        logits = torch.matmul(queries, keys.transpose(-2, -1)) / self.temperature  # [B, H, P, T]

        weights = F.softmax(logits, dim=-1)
        weights = torch.nan_to_num(weights, nan=0.0)
        weights = self.dropout(weights)

        # One token per prototype: [B, H, P, D/H] -> [B, P, D]
        proto_tokens = torch.matmul(weights, values)
        proto_tokens = proto_tokens.transpose(1, 2).contiguous().view(batch, n_proto, width)

        # torch.topk along dim=1 picks, independently for every feature, the k
        # largest values across prototypes; their mean is the summary vector.
        keep = max(1, int(n_proto * self.topk_ratio))
        top_values, _ = torch.topk(proto_tokens, k=keep, dim=1)
        z_tpa = self.out_proj(self.fuse(top_values.mean(dim=1)))  # [B, D]

        # Plain average over time of the un-normalised input as a fallback.
        z_gap = x.mean(dim=1)

        confidence = torch.sigmoid(self.conf_head(z_tpa))
        z = confidence * z_tpa + (1 - confidence) * z_gap

        if not return_confidence:
            return z
        return z, confidence

# ========================
# 4) Model Definitions
# ========================
class ConvBNAct(nn.Module):
    """1-D convolution (no bias) followed by BatchNorm and GELU.

    ``p`` overrides the padding; when it is ``None`` the padding is ``k // 2``.
    ``g`` is the number of convolution groups.
    """

    def __init__(self, c_in, c_out, k, s=1, p=None, g=1):
        super().__init__()
        padding = p if p is not None else k // 2
        self.c = nn.Conv1d(
            c_in, c_out,
            kernel_size=k, stride=s, padding=padding,
            groups=g, bias=False,
        )
        self.bn = nn.BatchNorm1d(c_out)
        self.act = nn.GELU()

    def forward(self, x):
        y = self.c(x)
        y = self.bn(y)
        return self.act(y)

class MultiPathCNN(nn.Module):
    """Multi-branch 1-D CNN backbone.

    A pointwise stem maps ``in_ch`` channels to ``d_model // 2``. Each branch
    then applies a strided depthwise convolution with its own kernel size
    followed by a pointwise convolution. The branch outputs are concatenated
    along the channel axis and projected to ``d_model`` channels.

    Args:
        in_ch: Number of input channels.
        d_model: Number of output channels.
        branches: Kernel sizes, one per branch. With the ``k // 2`` padding used
            by ``ConvBNAct`` all odd kernel sizes give equal output lengths,
            which the channel-wise concatenation relies on.
        stride: Stride of the depthwise convolution in every branch.
    """

    def __init__(self, in_ch=9, d_model=128, branches=(3,5,9,15), stride=2):
        super().__init__()
        h = d_model // 2
        self.pre = ConvBNAct(in_ch, h, 1)
        self.branches = nn.ModuleList([nn.Sequential(ConvBNAct(h, h, k, stride, g=h), ConvBNAct(h, h, 1)) for k in branches])
        self.post = ConvBNAct(len(branches)*h, d_model, 1)
        self.stride = stride

    def forward(self, x):
        """Map ``[B, in_ch, T]`` to a ``[B, d_model, T']`` feature map."""
        # Run the shared stem once. Calling it inside the branch loop repeated
        # the same computation per branch and also updated the stem's BatchNorm
        # running statistics several times per training step.
        stem = self.pre(x)
        return self.post(torch.cat([branch(stem) for branch in self.branches], dim=1))

class SimpleGAPHead(nn.Module):
    """Baseline head: average over time, then a linear classifier."""

    def __init__(self, d_model: int, num_classes: int):
        super().__init__()
        self.fc = nn.Linear(d_model, num_classes)

    def forward(self, Fmap):
        """Classify a ``[B, D, T]`` feature map.

        Returns ``(logits, aux)``; ``aux["confidence"]`` is always ``None``
        because this head has no confidence gate.
        """
        # [B, D, T] -> [B, T, D], then average across the time axis.
        time_major = Fmap.transpose(1, 2)
        logits = self.fc(time_major.mean(dim=1))
        return logits, {"confidence": None}

class PureTPAHead(nn.Module):
    """Classification head that pools with ``PureTPA`` instead of averaging."""

    def __init__(self, d_model: int, num_classes: int,
                 num_prototypes: int = 16, heads: int = 4, dropout: float = 0.1,
                 temperature: float = 0.07, topk_ratio: float = 0.25):
        super().__init__()

        pooling_kwargs = dict(
            dim=d_model,
            num_prototypes=num_prototypes,
            heads=heads,
            dropout=dropout,
            temperature=temperature,
            topk_ratio=topk_ratio,
        )
        self.tpa = PureTPA(**pooling_kwargs)
        self.classifier = nn.Linear(d_model, num_classes)

    def forward(self, Fmap):
        """Classify a ``[B, D, T]`` feature map.

        Returns ``(logits, aux)`` where ``aux["confidence"]`` is the batch-mean
        of the pooling confidence as a Python float.
        """
        time_major = Fmap.transpose(1, 2)  # [B, T, D]
        pooled, gate = self.tpa(time_major, return_confidence=True)
        return self.classifier(pooled), {"confidence": gate.mean().item()}

class HAR_Model(nn.Module):
    """``MultiPathCNN`` backbone followed by a GAP or Pure TPA head.

    Args:
        d_model: Feature width shared by backbone and head.
        num_classes: Number of output classes.
        use_tpa: Use ``PureTPAHead`` when true, ``SimpleGAPHead`` otherwise.
        tpa_config: Optional dict with any of the keys ``num_prototypes``,
            ``heads``, ``dropout``, ``temperature``, ``topk_ratio``. Missing
            keys (or ``None``) fall back to the defaults shown below. Ignored
            when ``use_tpa`` is false.
    """

    def __init__(self, d_model=128, num_classes=6, use_tpa=False, tpa_config=None):
        super().__init__()
        self.backbone = MultiPathCNN(d_model=d_model)
        self.use_tpa = use_tpa

        if use_tpa:
            # Allow HAR_Model(use_tpa=True) without an explicit config dict.
            tpa_config = tpa_config or {}
            self.head = PureTPAHead(
                d_model=d_model,
                num_classes=num_classes,
                num_prototypes=tpa_config.get('num_prototypes', 16),
                heads=tpa_config.get('heads', 4),
                dropout=tpa_config.get('dropout', 0.1),
                temperature=tpa_config.get('temperature', 0.07),
                topk_ratio=tpa_config.get('topk_ratio', 0.25)
            )
        else:
            self.head = SimpleGAPHead(d_model=d_model, num_classes=num_classes)

    def forward(self, x):
        """Return the head's ``(logits, aux)`` for an input batch ``x``."""
        fmap = self.backbone(x)
        return self.head(fmap)

# ========================
# 5) Train / Eval
# ========================
def train_one_epoch(model, loader, opt, cfg: Config, verbose_epoch: bool = False):
    """Train ``model`` for one pass over ``loader``.

    With probability ``cfg.train_augment_prob`` a batch goes through
    ``apply_transition_augmentation`` before the forward pass. Batches whose
    loss is NaN are skipped (reported only when ``verbose_epoch`` is true) and
    do not contribute to the statistics.

    Returns:
        Dict with ``loss`` and ``acc`` (sample-weighted over the batches that
        were actually used), ``aug_count`` (number of augmented batches) and
        ``avg_confidence`` (mean of the head's reported confidence, or ``None``
        when the head reports none).
    """
    model.train()

    seen = 0
    hits = 0
    weighted_loss = 0.0
    augmented_batches = 0
    confidences = []

    for batch_x, batch_y in loader:
        batch_x = batch_x.to(cfg.device).float()
        batch_y = batch_y.to(cfg.device)

        if random.random() < cfg.train_augment_prob:
            batch_x = apply_transition_augmentation(batch_x, batch_y, cfg.train_augment_mix)
            augmented_batches += 1

        opt.zero_grad(set_to_none=True)
        logits, aux = model(batch_x)
        loss = F.cross_entropy(logits, batch_y, label_smoothing=cfg.label_smoothing)

        if torch.isnan(loss):
            if verbose_epoch:
                print("  Warning: NaN loss detected, skipping batch")
            continue

        loss.backward()
        nn.utils.clip_grad_norm_(model.parameters(), cfg.grad_clip)
        opt.step()

        with torch.no_grad():
            n_batch = batch_y.size(0)
            hits += (logits.argmax(dim=-1) == batch_y).sum().item()
            seen += n_batch
            weighted_loss += loss.item() * n_batch
            batch_conf = aux["confidence"]
            if batch_conf is not None:
                confidences.append(batch_conf)

    if seen > 0:
        mean_loss, accuracy = weighted_loss / seen, hits / seen
    else:
        mean_loss, accuracy = 0, 0

    return {
        "loss": mean_loss,
        "acc": accuracy,
        "aug_count": augmented_batches,
        "avg_confidence": np.mean(confidences) if confidences else None,
    }

@torch.no_grad()
def evaluate(model, loader, cfg: Config, classes=6):
    """Evaluate ``model`` on ``loader`` and return detailed metrics.

    Args:
        model: Module returning ``(logits, aux)``.
        loader: Iterable of ``(x, y)`` batches.
        cfg: Configuration; only ``cfg.device`` is read.
        classes: Number of classes; labels are assumed to be ``0..classes-1``.

    Returns:
        ``(accuracy, macro_f1, confusion_matrix, classification_report_text)``.
    """
    model.eval()
    ys, ps = [], []
    for x, y in loader:
        x, y = x.to(cfg.device), y.to(cfg.device)
        logits, _ = model(x)
        ps.append(logits.argmax(dim=-1).cpu().numpy())
        ys.append(y.cpu().numpy())

    y_true, y_pred = np.concatenate(ys), np.concatenate(ps)
    label_ids = list(range(classes))

    acc = accuracy_score(y_true, y_pred)
    f1 = f1_score(y_true, y_pred, average='macro')
    cm = confusion_matrix(y_true, y_pred, labels=label_ids)
    # Pass `labels` explicitly: without it the report infers the label set from
    # the data and raises when that set does not match the length of
    # `target_names` (e.g. a class that never occurs in this evaluation).
    report = classification_report(
        y_true, y_pred,
        labels=label_ids,
        target_names=[_CODE_TO_LABEL_NAME[i] for i in label_ids],
        digits=4,
        zero_division=0,
    )
    return acc, f1, cm, report

@torch.no_grad()
def evaluate_simple(model, loader, cfg: Config):
    """Accuracy-only evaluation; returns ``{'accuracy': <value>}``."""
    model.eval()

    targets = []
    predictions = []
    for batch_x, batch_y in loader:
        batch_x = batch_x.to(cfg.device)
        batch_y = batch_y.to(cfg.device)

        logits, _ = model(batch_x)

        targets.append(batch_y.cpu().numpy())
        predictions.append(logits.argmax(dim=-1).cpu().numpy())

    return {'accuracy': accuracy_score(np.concatenate(targets), np.concatenate(predictions))}

# ========================
# 6) Extreme Transition Test Set
# ========================
def create_transitional_test_set(orig_dataset: UCIHARInertial, class_A: str, class_B: str,
                                 p: float=0.05, mix: float=0.25, profile: str='abrupt',
                                 pos: str='tail', segments: int=1) -> Tuple[UCIHARInertial, dict]:
    """Build a copy of ``orig_dataset`` with simulated A<->B activity transitions.

    A fraction ``p`` of the class-A windows gets part of its signal replaced by
    (or blended with) the same time span of a randomly drawn class-B window,
    and vice versa. Labels are never changed. When ``p > 0.5`` additional
    class-A windows that are still untouched receive a transition over the
    middle third of the window.

    Args:
        orig_dataset: Source dataset; its ``X``/``y`` arrays are copied, not modified.
        class_A: Name of the first activity (key of ``_LABEL_NAME_TO_CODE``).
        class_B: Name of the second activity.
        p: Fraction of each class to modify (at least one window per class).
        mix: Fraction of the window length covered by the transition.
        profile: ``'abrupt'`` (hard replacement) or ``'fade'`` (linear blend
            from target to source across each segment).
        pos: ``'tail'``, ``'middle'`` or ``'random'`` placement of the transition.
        segments: Number of pieces the transition is split into.

    Returns:
        ``(dataset, info)`` where ``dataset`` is a new ``UCIHARInertial`` sharing
        the original normalisation statistics and ``info`` summarises what was
        modified.

    Raises:
        ValueError: For an unknown ``profile``/``pos``, ``segments < 1``, or
            when either class has no windows in ``orig_dataset``.
        KeyError: When ``class_A``/``class_B`` is not a known activity name.
    """
    # Reject unknown options up front: previously they were silent no-ops that
    # still reported windows as "modified".
    if profile not in ('abrupt', 'fade'):
        raise ValueError(f"profile must be 'abrupt' or 'fade', got {profile!r}")
    if pos not in ('tail', 'middle', 'random'):
        raise ValueError(f"pos must be 'tail', 'middle' or 'random', got {pos!r}")
    if segments < 1:
        raise ValueError(f"segments must be >= 1, got {segments!r}")

    X, y = orig_dataset.X.copy(), orig_dataset.y.copy()
    N, C, T = X.shape

    code_A, code_B = _LABEL_NAME_TO_CODE[class_A], _LABEL_NAME_TO_CODE[class_B]
    idx_A, idx_B = np.where(y == code_A)[0], np.where(y == code_B)[0]
    if len(idx_A) == 0 or len(idx_B) == 0:
        raise ValueError(
            f"both classes need at least one sample "
            f"({class_A}: {len(idx_A)}, {class_B}: {len(idx_B)})"
        )

    mix_pts = int(T * mix)
    # Every segment gets seg_length points; the last one also gets the remainder.
    seg_length, remaining = divmod(mix_pts, segments)

    modified_indices = []
    modified_lookup = set()  # O(1) membership test for the p > 0.5 pass

    def get_transition_positions():
        """Start index of every transition segment for one window."""
        if segments == 1:
            if pos == 'tail':
                return [T - mix_pts]
            if pos == 'middle':
                return [(T - mix_pts) // 2]
            return [random.randint(0, max(0, T - mix_pts))]

        if pos == 'tail':
            first = T - mix_pts
            return [first + i * seg_length for i in range(segments)]
        if pos == 'middle':
            # Segments are spaced two segment-lengths apart around the centre.
            total_span = mix_pts + (segments - 1) * seg_length
            first = T // 2 - total_span // 2
            return [first + i * (seg_length * 2) for i in range(segments)]

        available_positions = list(range(0, T - seg_length))
        random.shuffle(available_positions)
        return sorted(available_positions[:segments])

    def apply_transition(target_data, source_data, start_pos, length):
        """Overwrite/blend ``length`` steps of ``target_data`` in place."""
        end_pos = start_pos + length
        if profile == 'abrupt':
            target_data[:, start_pos:end_pos] = source_data[:, start_pos:end_pos].copy()
        else:  # 'fade'
            alpha = np.linspace(0, 1, length).reshape(1, -1)
            target_segment = target_data[:, start_pos:end_pos]
            source_segment = source_data[:, start_pos:end_pos]
            target_data[:, start_pos:end_pos] = (
                target_segment * (1 - alpha) + source_segment * alpha
            )

    def inject(targets, sources):
        """Apply the configured transition to every (target, source) pair."""
        for t, s in zip(targets, sources):
            positions = get_transition_positions()
            last = len(positions) - 1

            for i, start in enumerate(positions):
                curr_len = seg_length + (remaining if i == last else 0)
                # A negative start would wrap around when slicing and break
                # the fade blend; clamp it to the beginning of the window.
                start = max(0, start)
                curr_len = min(curr_len, T - start)

                if curr_len > 0:
                    apply_transition(X[t], orig_dataset.X[s], start, curr_len)

            modified_indices.append(t)
            modified_lookup.add(int(t))

    # Class A windows receive class B signal ...
    n_targets_A = max(1, int(len(idx_A) * p))
    targets_A = np.random.choice(idx_A, n_targets_A, replace=False)
    sources_B = np.random.choice(idx_B, len(targets_A), replace=True)
    inject(targets_A, sources_B)

    # ... and class B windows receive class A signal.
    n_targets_B = max(1, int(len(idx_B) * p))
    targets_B = np.random.choice(idx_B, n_targets_B, replace=False)
    sources_A = np.random.choice(idx_A, len(targets_B), replace=True)
    inject(targets_B, sources_A)

    if p > 0.5:
        # Extra pass: corrupt the middle third of further class-A windows that
        # have not been touched yet.
        mid_start = T // 3
        mid_end = 2 * T // 3
        mid_length = mid_end - mid_start

        extra_A = np.random.choice(idx_A, max(1, int(len(idx_A) * p * 0.3)), replace=False)
        extra_B_src = np.random.choice(idx_B, len(extra_A), replace=True)

        for t, s in zip(extra_A, extra_B_src):
            if int(t) not in modified_lookup:
                apply_transition(X[t], orig_dataset.X[s], mid_start, mid_length)
                modified_indices.append(t)
                modified_lookup.add(int(t))

    # The data is already normalised, so only pass the statistics along.
    mod_dataset = UCIHARInertial(
        root="", split="test",
        mean=orig_dataset.mean, std=orig_dataset.std,
        preloaded_data=(X, y)
    )

    info = {
        'total_samples': N,
        'modified_samples': len(modified_indices),
        'modified_ratio': len(modified_indices) / N,
        'mix_frames': mix_pts,
        'primary_class_ratio': 1 - mix,
        'class_A_modified': len(targets_A),
        'class_B_modified': len(targets_B),
        'profile': profile,
        'position': pos,
        'segments': segments
    }

    return mod_dataset, info

def get_transition_scenarios():
    """Return the list of transition scenarios used for robustness testing.

    Each scenario is a tuple
    ``(class_A, class_B, p, mix, profile, pos, segments)``. The list is the
    core scenarios, then the stress scenarios, then the control scenarios.
    """
    # Activity pairs in evaluation order; STANDING/SITTING appears twice.
    activity_pairs = [
        ("STANDING", "SITTING"),
        ("WALKING", "WALKING_UPSTAIRS"),
        ("SITTING", "LAYING"),
        ("WALKING", "WALKING_DOWNSTAIRS"),
        ("STANDING", "SITTING"),
    ]

    # Core: every pair once with an abrupt tail and once with a random fade.
    core = []
    for cls_a, cls_b in activity_pairs:
        core.append((cls_a, cls_b, 0.70, 0.55, "abrupt", "tail", 1))
        core.append((cls_a, cls_b, 0.70, 0.55, "fade", "random", 1))

    # Stress: the same pairs with a larger affected fraction and mix.
    stress = [
        (cls_a, cls_b, 0.75, 0.58, "abrupt", "tail", 1)
        for cls_a, cls_b in activity_pairs
    ]

    # Control: two-segment transitions placed in the middle of the window.
    control = [
        ("WALKING", "WALKING_DOWNSTAIRS", 0.70, 0.55, "abrupt", "middle", 2),
        ("SITTING", "LAYING", 0.70, 0.55, "fade", "middle", 2),
    ]

    return core + stress + control

# ========================
# 7) GAP vs Pure TPA Comparison
# ========================
def run_ablation_study(cfg: Config):
    """Train and compare the GAP head and the Pure TPA head.

    Steps:
      1. Load the training split and make a stratified train/validation split.
      2. Build the transitional test sets from the normalised test split.
      3. For each configuration (GAP, Pure_TPA): reseed, train with early
         stopping on validation accuracy, reload the best weights, then
         evaluate on the original test set and on every transitional set.
      4. Print a comparison and write it to ``pure_tpa_results.json`` in
         ``cfg.save_dir``.

    Side effects: creates ``cfg.save_dir``, writes one ``.pth`` checkpoint per
    configuration plus the results JSON, overwrites ``cfg.use_tpa``, and
    reseeds the global RNGs.
    """
    os.makedirs(cfg.save_dir, exist_ok=True)

    # Load full train set
    # NOTE(review): this instance is only used for its length and labels; the
    # same files are read from disk again by _load_split_raw further down.
    train_set_full = UCIHARInertial(cfg.data_dir, "train")

    # Split train into train + validation (stratified by class)
    print(f"\n{'='*70}")
    print("   DATA PREPARATION")
    print(f"{'='*70}")

    n_samples = len(train_set_full)
    indices = np.arange(n_samples)
    labels = train_set_full.y

    # Stratified split
    from sklearn.model_selection import train_test_split
    train_indices, val_indices = train_test_split(
        indices,
        test_size=cfg.val_split,
        random_state=SEED,
        stratify=labels
    )

    print(f"Total train samples: {n_samples}")
    print(f"  → Train split: {len(train_indices)} ({(1-cfg.val_split)*100:.0f}%)")
    print(f"  → Val split:   {len(val_indices)} ({cfg.val_split*100:.0f}%)")

    # Verify class distribution
    train_labels = labels[train_indices]
    val_labels = labels[val_indices]
    print(f"\nClass distribution check:")
    for cls in range(6):
        train_count = (train_labels == cls).sum()
        val_count = (val_labels == cls).sum()
        print(f"  Class {cls} ({_CODE_TO_LABEL_NAME[cls]}): Train={train_count}, Val={val_count}")

    # Create datasets using indices
    # NOTE(review): the normalisation statistics are computed over the whole
    # training split, i.e. including the rows later used for validation.
    X_full, y_full = _load_split_raw(cfg.data_dir, "train")
    mean = X_full.mean(axis=(0,2), keepdims=True)
    std = X_full.std(axis=(0,2), keepdims=True) + 1e-6
    X_full = ((X_full - mean) / std).astype(np.float32)

    # X_full is already normalised; with preloaded_data the dataset stores it
    # as-is and only records mean/std.
    train_set = UCIHARInertial(
        cfg.data_dir, "train",
        mean=mean, std=std,
        preloaded_data=(X_full, y_full),
        indices=train_indices
    )

    val_set = UCIHARInertial(
        cfg.data_dir, "train",
        mean=mean, std=std,
        preloaded_data=(X_full, y_full),
        indices=val_indices
    )

    # Loaded from disk and normalised with the training statistics.
    test_set_orig = UCIHARInertial(cfg.data_dir, "test", mean=mean, std=std)

    val_loader = DataLoader(val_set, cfg.batch_size, num_workers=cfg.num_workers)
    test_loader_orig = DataLoader(test_set_orig, cfg.batch_size, num_workers=cfg.num_workers)

    scenarios = get_transition_scenarios()

    print(f"\n{'='*70}")
    print("    GAP vs PURE TPA COMPARISON")
    print(f"{'='*70}")
    print(f"   Goal: Compare GAP and Pure TPA (no conv preprocessing)")
    print(f"\n   Training: Both configs use SAME augmentation")
    print(f"   Testing: Differ in pooling method only")
    print(f"   Scenarios: {len(scenarios)} extreme transitions (p=0.70-0.75, mix=0.55-0.58)")
    print(f"{'='*70}\n")

    # Reseed so the transitional test sets are reproducible.
    random.seed(SEED); np.random.seed(SEED); torch.manual_seed(SEED); torch.cuda.manual_seed_all(SEED)

    # Create transitional test sets
    # They are built once and shared by both configurations.
    transition_sets, transition_infos = [], []
    for i, scenario in enumerate(scenarios):
        clsA, clsB, p, mix, profile, pos, segments = scenario
        print(f"  [{i+1}/{len(scenarios)}] {clsA} ↔ {clsB} (p={p:.2f}, mix={mix:.2f}, {profile}, {pos}, seg={segments})")
        print(f"      Primary class: {(1-mix)*100:.0f}% | Transition: {mix*100:.0f}%")
        test_set_mod, info = create_transitional_test_set(
            test_set_orig, clsA, clsB, p=p, mix=mix, profile=profile, pos=pos, segments=segments
        )
        transition_sets.append(test_set_mod)
        transition_infos.append(info)
        print(f"      Modified: {info['modified_samples']}/{info['total_samples']} ({info['modified_ratio']*100:.1f}%)")

    transition_loaders = [DataLoader(ts, cfg.batch_size, num_workers=cfg.num_workers) for ts in transition_sets]
    print(f"\n✓ {len(transition_loaders)} transitional test sets created.\n")

    # 2-way ablation configurations
    ablation_configs = [
        {"name": "GAP", "use_tpa": False},
        {"name": "Pure_TPA", "use_tpa": True},
    ]

    results_table = []

    # Train and evaluate each configuration
    for ab_cfg in ablation_configs:
        print(f"\n{'='*70}")
        print(f"   CONFIG: {ab_cfg['name']}")
        print(f"   Pooling: {ab_cfg['name']}")
        print(f"{'='*70}")

        # Same seed for every configuration so that initialisation, shuffling
        # and augmentation draws start from the same RNG state.
        random.seed(SEED); np.random.seed(SEED); torch.manual_seed(SEED); torch.cuda.manual_seed_all(SEED)

        # The shared cfg object is mutated here.
        cfg.use_tpa = ab_cfg["use_tpa"]

        # Dedicated generator so the shuffle order is identical across configs.
        g = torch.Generator(device='cpu').manual_seed(SEED)
        train_loader = DataLoader(train_set, cfg.batch_size, shuffle=True, num_workers=cfg.num_workers, generator=g)

        model_path = os.path.join(cfg.save_dir, f"model_{ab_cfg['name']}.pth")

        tpa_config = {
            'num_prototypes': cfg.tpa_num_prototypes,
            'heads': cfg.tpa_heads,
            'dropout': cfg.tpa_dropout,
            'temperature': cfg.tpa_temperature,
            'topk_ratio': cfg.tpa_topk_ratio
        }

        model = HAR_Model(
            d_model=cfg.d_model,
            use_tpa=cfg.use_tpa,
            tpa_config=tpa_config
        ).to(cfg.device).float()

        opt = torch.optim.AdamW(model.parameters(), lr=cfg.lr, weight_decay=cfg.weight_decay)

        best_acc, best_wts = 0.0, None
        patience_counter = 0
        best_epoch = 0

        print(f"\nTraining {ab_cfg['name']} for up to {cfg.epochs} epochs (patience={cfg.patience})...")
        if cfg.use_tpa:
            print(f"Pure TPA: prototypes={cfg.tpa_num_prototypes}, heads={cfg.tpa_heads}, temp={cfg.tpa_temperature}")
            print(f"  → NO convolutional preprocessing (lowpass, depthwise, pointwise removed)")

        for epoch in range(1, cfg.epochs + 1):
            # Verbose output for first epoch and every 10th epoch
            verbose = (epoch == 1 or epoch % 10 == 0)

            stats = train_one_epoch(model, train_loader, opt, cfg, verbose_epoch=verbose)

            # Evaluate on VALIDATION set for early stopping
            val_acc, val_f1, _, _ = evaluate(model, val_loader, cfg)

            # Early stopping logic based on VALIDATION accuracy
            improved = False
            if val_acc > best_acc + cfg.min_delta:
                best_acc = val_acc
                # Snapshot the weights; state_dict() alone returns live references.
                best_wts = copy.deepcopy(model.state_dict())
                patience_counter = 0
                best_epoch = epoch
                improved = True
            else:
                patience_counter += 1

            # Only print for verbose epochs
            if verbose:
                log_str = f"[{epoch:02d}/{cfg.epochs}] Train L:{stats['loss']:.4f} A:{stats['acc']:.4f}"
                log_str += f" Aug:{stats['aug_count']}"
                log_str += f" | Val A:{val_acc:.4f} F1:{val_f1:.4f}"
                if stats['avg_confidence'] is not None:
                    log_str += f" | Conf:{stats['avg_confidence']:.3f}"
                if improved:
                    log_str += " ✓"
                print(log_str)

            # Early stopping check
            if patience_counter >= cfg.patience:
                print(f"\n⚠ Early stopping triggered at epoch {epoch}")
                print(f"  No improvement for {cfg.patience} epochs on validation set")
                print(f"  Best validation acc: {best_acc:.4f} (epoch {best_epoch})")
                break

        # Restore and persist the best validation checkpoint (if any epoch improved).
        if best_wts:
            torch.save(best_wts, model_path)
            model.load_state_dict(best_wts)
            print(f"\n✓ Best Val Acc: {best_acc:.4f} (epoch {best_epoch})")

        # Final evaluation on TEST set
        acc_orig, f1_orig, _, _ = evaluate(model, test_loader_orig, cfg)
        print(f"  Final Test Acc: {acc_orig:.4f}, F1: {f1_orig:.4f}")

        print(f"\n   Evaluating on {len(transition_loaders)} transitional test sets...")
        transition_accs = []
        scenario_details = []

        for i, loader in enumerate(transition_loaders):
            result = evaluate_simple(model, loader, cfg)
            acc_mod = result['accuracy']
            transition_accs.append(acc_mod)

            scenario = scenarios[i]
            clsA, clsB, p, mix = scenario[0], scenario[1], scenario[2], scenario[3]
            primary_ratio = (1 - mix) * 100
            # Accuracy lost relative to the unmodified test set.
            drop_from_orig = acc_orig - acc_mod

            # Grade thresholds: < 2 points, < 5 points, otherwise vulnerable.
            if drop_from_orig < 0.02:
                grade = "Very Robust"
            elif drop_from_orig < 0.05:
                grade = "Slightly Vulnerable"
            else:
                grade = "Vulnerable"

            scenario_details.append({
                'scenario': f"{clsA}↔{clsB}",
                'primary_ratio': primary_ratio,
                'acc': acc_mod,
                'drop': drop_from_orig,
                'grade': grade
            })

            print(f"    Scenario {i+1}: Acc={acc_mod:.4f} Drop={drop_from_orig:.4f} [{grade}]")

        avg_trans_acc = np.mean(transition_accs)
        avg_drop = acc_orig - avg_trans_acc
        # Share of the original accuracy that survives, in percent.
        retention = (1 - avg_drop/acc_orig) * 100 if acc_orig > 0 else 0

        result = {
            "config": ab_cfg["name"],
            "use_tpa": ab_cfg["use_tpa"],
            "orig_acc": acc_orig,
            "avg_trans_acc": avg_trans_acc,
            "avg_drop": avg_drop,
            "retention": retention,
            "scenario_details": scenario_details
        }
        results_table.append(result)

        print(f"\n {ab_cfg['name']} Summary:")
        print(f"   Original Test:      {acc_orig:.4f}")
        print(f"   Avg Transition:     {avg_trans_acc:.4f}")
        print(f"   Avg Drop:           {avg_drop:.4f}")
        print(f"   Retention:          {retention:.2f}%")

    # ========================
    # Analysis and Comparison
    # ========================
    print(f"\n{'='*70}")
    print("   GAP vs PURE TPA RESULTS")
    print(f"{'='*70}\n")
    print(f"{'Config':<15} {'Pooling':<15} {'Orig':<8} {'Trans':<8} {'Drop':<8} {'Retention':<10}")
    print("-" * 70)

    for r in results_table:
        pooling = "Pure TPA" if r['use_tpa'] else "GAP"
        print(f"{r['config']:<15} {pooling:<15} {r['orig_acc']:<8.4f} {r['avg_trans_acc']:<8.4f} {r['avg_drop']:<8.4f} {r['retention']:<10.2f}%")

    # Compute effects
    print(f"\n{'='*70}")
    print("   EFFECT ANALYSIS")
    print(f"{'='*70}\n")

    # Find each config
    gap = next(r for r in results_table if not r['use_tpa'])
    tpa = next(r for r in results_table if r['use_tpa'])

    # Pure TPA effect
    # Positive tpa_effect means Pure TPA loses less accuracy than GAP.
    tpa_effect = gap['avg_drop'] - tpa['avg_drop']
    tpa_improve = (tpa_effect / gap['avg_drop'] * 100) if gap['avg_drop'] > 0 else 0

    print(f"PURE TPA EFFECT (vs GAP):")
    print(f"   GAP  →  Pure TPA")
    print(f"   Drop: {gap['avg_drop']:.4f} → {tpa['avg_drop']:.4f}")
    print(f"   Absolute improvement: {tpa_effect:+.4f}")
    print(f"   Relative improvement: {tpa_improve:+.2f}% drop reduction")
    print(f"   Retention gain: {tpa['retention'] - gap['retention']:+.2f}pp")

    if tpa_effect > 0:
        print(f"\n   ✓ Pure TPA helps by reducing performance drop on transitional data")
        print(f"     This shows that prototype-based attention alone (without conv)")
        print(f"     provides robustness benefits over simple global average pooling")
    elif tpa_effect < -0.01:
        print(f"\n   ✗ Pure TPA performs worse than GAP")
        print(f"     The conv preprocessing layers may be important for TPA performance")
    else:
        print(f"\n   ≈ Pure TPA and GAP show similar performance")
        print(f"     Prototype attention alone may not provide significant advantage")

    # Original accuracy comparison
    orig_diff = tpa['orig_acc'] - gap['orig_acc']
    print(f"\n\nORIGINAL TEST ACCURACY:")
    print(f"   GAP:      {gap['orig_acc']:.4f}")
    print(f"   Pure TPA: {tpa['orig_acc']:.4f}")
    print(f"   Difference: {orig_diff:+.4f}")

    if abs(orig_diff) < 0.01:
        print(f"   → Similar baseline performance")
    elif orig_diff > 0:
        print(f"   → Pure TPA has higher baseline accuracy")
    else:
        print(f"   → GAP has higher baseline accuracy")

    # Save results
    with open(os.path.join(cfg.save_dir, "pure_tpa_results.json"), "w") as f:
        json.dump({
            'results': results_table,
            'analysis': {
                'tpa_effect': float(tpa_effect),
                'tpa_improve_pct': float(tpa_improve),
                'orig_acc_diff': float(orig_diff)
            }
        }, f, indent=2)

    print(f"\n✓ Results saved to '{cfg.save_dir}/pure_tpa_results.json'")
    print(f"{'='*70}\n")

# ========================
# 8) Main Execution
# ========================
if __name__ == "__main__":
    # Every value that differs from (or deliberately restates) the Config
    # defaults is given in one place.
    config = Config(
        mode="ablation",
        epochs=100,
        lr=1e-4,
        train_augment_prob=0.25,
        train_augment_mix=0.35,
        # Early stopping
        patience=20,
        min_delta=0.0001,
        val_split=0.2,
        # Pure TPA hyperparameters (no seg_kernel needed)
        tpa_num_prototypes=16,
        tpa_heads=4,
        tpa_dropout=0.1,
        tpa_temperature=0.07,
        tpa_topk_ratio=0.25,
    )

    rule = "="*70
    banner = [
        "\n" + rule,
        "    UCI-HAR GAP vs PURE TPA COMPARISON",
        rule,
        f"Device: {config.device}",
        f"Epochs: {config.epochs}",
        f"Learning Rate: {config.lr}",
        f"Train Augmentation: prob={config.train_augment_prob}, mix={config.train_augment_mix}",
        f"\n2 Configurations to Compare:",
        f"  1) GAP:      Global Average Pooling (baseline)",
        f"  2) Pure TPA: Temporal Prototype Attention (NO conv preprocessing)",
        f"\nThis allows us to measure:",
        f"  • Effect of pure prototype attention vs GAP",
        f"  • Whether TPA alone (without conv layers) provides robustness",
        f"  • Robustness to transitional noise",
        f"\nPure TPA Configuration:",
        f"  Prototypes: {config.tpa_num_prototypes}",
        f"  Heads: {config.tpa_heads}",
        f"  Temperature: {config.tpa_temperature}",
        f"  TopK Ratio: {config.tpa_topk_ratio}",
        f"  ⚠ Conv layers REMOVED: No lowpass, depthwise, or pointwise conv",
        rule + "\n",
    ]
    # One print of the joined lines emits the same text as one print per line.
    print("\n".join(banner))

    if config.mode != "ablation":
        print("✗ Invalid mode. Set config.mode = 'ablation'")
    else:
        run_ablation_study(config)


    UCI-HAR GAP vs PURE TPA COMPARISON
Device: cuda
Epochs: 100
Learning Rate: 0.0001
Train Augmentation: prob=0.25, mix=0.35

2 Configurations to Compare:
  1) GAP:      Global Average Pooling (baseline)
  2) Pure TPA: Temporal Prototype Attention (NO conv preprocessing)

This allows us to measure:
  • Effect of pure prototype attention vs GAP
  • Whether TPA alone (without conv layers) provides robustness
  • Robustness to transitional noise

Pure TPA Configuration:
  Prototypes: 16
  Heads: 4
  Temperature: 0.07
  TopK Ratio: 0.25
  ⚠ Conv layers REMOVED: No lowpass, depthwise, or pointwise conv


   DATA PREPARATION
Total train samples: 7352
  → Train split: 5881 (80%)
  → Val split:   1471 (20%)

Class distribution check:
  Class 0 (WALKING): Train=981, Val=245
  Class 1 (WALKING_UPSTAIRS): Train=858, Val=215
  Class 2 (WALKING_DOWNSTAIRS): Train=789, Val=197
  Class 3 (SITTING): Train=1029, Val=257
  Class 4 (STANDING): Train=1099, Val=275
  Class 5 (LAYING): Train=1125, Val=282