In [None]:
# -*- coding: utf-8 -*-
"""
Simplified Model Comparison: GAP, TPA, Gated-TPA
- TPA Top-k 마스킹 적용
- 프로토타입 다양성 페널티
- 로짓 수준 MoE 융합 (별도 분류기)
"""

from google.colab import drive
drive.mount('/content/drive')

import os, random, time, copy, json
import numpy as np
from typing import Tuple, Dict, List
from dataclasses import dataclass

import torch
import torch.nn as nn
import torch.nn.functional as F
from torch.utils.data import Dataset, DataLoader

from sklearn.metrics import accuracy_score, f1_score
from sklearn.model_selection import train_test_split

# ========================
# Config & Reproducibility
# ========================
# Single seed shared by Python's `random`, NumPy and PyTorch so that runs are
# repeatable. The same constant is reused further down for the train/val
# split and the DataLoader generator.
SEED = 42
random.seed(SEED)
np.random.seed(SEED)
torch.manual_seed(SEED)
torch.cuda.manual_seed_all(SEED)
# Ask cuDNN for deterministic kernels and turn off its autotuner, trading
# some speed for repeatable results.
torch.backends.cudnn.deterministic = True
torch.backends.cudnn.benchmark = False

@dataclass
class Config:
    """Experiment configuration: paths, optimisation settings and model sizes.

    A single module-level instance (`cfg`) is created right after the class
    and handed to the training / evaluation helpers.
    """
    # Root folder holding one sub-directory per dataset variant.
    data_dir: str = "/content/drive/MyDrive/AI_data/TPA2/pamap2_transition_datasets"
    # Folder the final JSON report is written to.
    save_dir: str = "/content/drive/MyDrive/AI_data/TPA2"

    # Optimisation
    epochs: int = 100
    batch_size: int = 128
    lr: float = 1e-4
    weight_decay: float = 1e-4
    grad_clip: float = 1.0          # max gradient norm passed to clip_grad_norm_
    label_smoothing: float = 0.05

    # Early stopping / validation split
    patience: int = 20              # epochs without improvement before stopping
    min_delta: float = 0.0001       # minimum val-accuracy gain that counts as improvement
    val_split: float = 0.2          # fraction of the training set held out for validation

    d_model: int = 128

    # Transformer hyperparameters
    num_layers: int = 2
    n_heads: int = 4
    ff_dim: int = 256
    dropout: float = 0.1

    # TPA hyperparameters
    tpa_num_prototypes: int = 16
    tpa_heads: int = 4
    tpa_dropout: float = 0.1
    # NOTE(review): ImprovedTPA stores this value but its forward pass scales
    # scores by sqrt(head_dim) instead - confirm whether that is intended.
    tpa_temperature: float = 0.07
    tpa_topk_ratio: float = 0.25    # fraction of time steps kept per prototype

    # Newly added hyperparameters
    diversity_weight: float = 5e-3  # weight of the prototype diversity penalty
    use_logit_fusion: bool = True    # whether Gated-TPA fuses at the logit level

    # Resolved once, when the class body is executed.
    device: str = "cuda" if torch.cuda.is_available() else "cpu"
    num_workers: int = 2

cfg = Config()

# ========================
# Dataset Class
# ========================
class PreloadedDataset(Dataset):
    """In-memory dataset wrapping a feature array and a label array.

    Features are stored as a float32 tensor and labels as an int64 tensor.
    If the smallest label is >= 1 the labels are shifted down by one so they
    start at zero.

    NOTE(review): the shift is decided per array, so a train and a test split
    are only shifted consistently when both share the same minimum label.
    """

    def __init__(self, X: np.ndarray, y: np.ndarray):
        super().__init__()
        # Shift 1-based labels to 0-based; leave already 0-based labels alone.
        labels = y - 1 if y.min() >= 1 else y
        self.X = torch.from_numpy(X).to(torch.float32)
        self.y = torch.from_numpy(labels).to(torch.int64)

    def __len__(self):
        return len(self.X)

    def __getitem__(self, idx):
        sample = self.X[idx]
        target = self.y[idx]
        return sample, target

# ========================
# Data Loading Functions
# ========================
def load_dataset(base_dir: str, dataset_name: str):
    """Read one pre-built dataset variant from disk.

    Expects `<base_dir>/<dataset_name>/` to contain `X_train.npy`,
    `y_train.npy`, `X_test.npy` and `y_test.npy`.

    Args:
        base_dir: directory that contains all dataset variants.
        dataset_name: sub-directory name, e.g. "ORIGINAL" or
            "Standing_TO_Sitting_10PCT".
    Returns:
        (train_dataset, test_dataset) as PreloadedDataset instances.
    """
    dataset_dir = os.path.join(base_dir, dataset_name)

    print(f"\nLoading {dataset_name}...")
    print(f"  Path: {dataset_dir}")

    # Load the four arrays in a fixed order, keyed by file stem.
    arrays = {
        stem: np.load(os.path.join(dataset_dir, stem + ".npy"))
        for stem in ("X_train", "y_train", "X_test", "y_test")
    }

    print(f"  Train: {arrays['X_train'].shape}, Test: {arrays['X_test'].shape}")

    train_dataset = PreloadedDataset(arrays["X_train"], arrays["y_train"])
    test_dataset = PreloadedDataset(arrays["X_test"], arrays["y_test"])
    return train_dataset, test_dataset

# ========================
# Transformer Backbone Components
# ========================
class PositionalEncoding(nn.Module):
    """Fixed sinusoidal positional encoding followed by dropout.

    Even feature indices receive sin(pos * w_i) and odd indices cos(pos * w_i)
    with w_i = 10000^(-2i / d_model). The table is registered as the buffer
    `pe` of shape [1, max_len, d_model].

    Args:
        d_model: feature dimension; both even and odd values are supported.
        max_len: number of positions pre-computed. Inputs longer than this
            cannot be broadcast against the table and will raise in forward.
        dropout: dropout probability applied after adding the encoding.
    """
    def __init__(self, d_model: int, max_len: int = 5000, dropout: float = 0.1):
        super().__init__()
        self.dropout = nn.Dropout(p=dropout)

        pe = torch.zeros(max_len, d_model)
        position = torch.arange(0, max_len, dtype=torch.float).unsqueeze(1)
        div_term = torch.exp(torch.arange(0, d_model, 2).float() * (-np.log(10000.0) / d_model))

        pe[:, 0::2] = torch.sin(position * div_term)
        # For odd d_model there is one cosine column fewer than sine columns,
        # so trim div_term to match; for even d_model the slice is a no-op.
        pe[:, 1::2] = torch.cos(position * div_term[: d_model // 2])

        # [1, max_len, d_model] so it broadcasts over the batch dimension.
        self.register_buffer('pe', pe.unsqueeze(0))

    def forward(self, x):
        """
        Args:
            x: [B, T, D]
        Returns:
            [B, T, D] - input plus the first T rows of the table, after dropout
        """
        x = x + self.pe[:, :x.size(1), :]
        return self.dropout(x)

class TransformerBackbone(nn.Module):
    """
    Lightweight Transformer encoder backbone.

    Pipeline: linear projection of the last input dimension to d_model ->
    sinusoidal positional encoding -> stack of pre-LN Transformer encoder
    layers (GELU, batch_first) -> final LayerNorm.

    Defaults: 2 layers, d_model=128, n_heads=4, ff_dim=256, dropout=0.1.

    Args:
        in_channels: size of the last input dimension (features per time step).
        d_model: model width.
        num_layers: number of encoder layers.
        n_heads: attention heads per layer.
        ff_dim: hidden width of each feed-forward sub-layer.
        dropout: dropout used in the positional encoding and encoder layers.
        max_seq_len: number of positions in the positional-encoding table.
    """
    def __init__(self,
                 in_channels: int = 27,
                 d_model: int = 128,
                 num_layers: int = 2,
                 n_heads: int = 4,
                 ff_dim: int = 256,
                 dropout: float = 0.1,
                 max_seq_len: int = 200):
        super().__init__()

        self.d_model = d_model

        # Input projection over the last dimension: [B, T, C] -> [B, T, D]
        self.input_projection = nn.Linear(in_channels, d_model)

        # Positional encoding
        self.pos_encoder = PositionalEncoding(d_model, max_seq_len, dropout)

        # Transformer Encoder layers
        encoder_layer = nn.TransformerEncoderLayer(
            d_model=d_model,
            nhead=n_heads,
            dim_feedforward=ff_dim,
            dropout=dropout,
            activation='gelu',
            batch_first=True,
            norm_first=True  # Pre-LN for better stability
        )

        self.transformer_encoder = nn.TransformerEncoder(
            encoder_layer,
            num_layers=num_layers
        )

        # Output normalization
        self.norm = nn.LayerNorm(d_model)

    def forward(self, x):
        """
        Args:
            x: [B, T, C] - input sensor data, channels last (the linear
               projection below acts on the last dimension, so it must equal
               in_channels)
        Returns:
            [B, T, D] - transformed sequence
        """
        # No transpose is applied here: the input must already be laid out
        # as [B, T, C].

        # Project to d_model: [B, T, C] -> [B, T, D]
        x = self.input_projection(x)

        # Add positional encoding: [B, T, D]
        x = self.pos_encoder(x)

        # Transformer encoding: [B, T, D]
        x = self.transformer_encoder(x)

        # Final normalization: [B, T, D]
        x = self.norm(x)

        return x

# ========================
# GAP Model
# ========================
class GAPModel(nn.Module):
    """Baseline classifier: Transformer backbone + mean pooling over time.

    The backbone output [B, T, D] is averaged along the time axis and fed to
    a single linear layer that produces `num_classes` logits.
    """
    def __init__(self,
                 in_channels: int = 27,
                 d_model: int = 128,
                 num_layers: int = 2,
                 n_heads: int = 4,
                 ff_dim: int = 256,
                 dropout: float = 0.1,
                 num_classes: int = 12):
        super().__init__()
        backbone_kwargs = dict(
            in_channels=in_channels,
            d_model=d_model,
            num_layers=num_layers,
            n_heads=n_heads,
            ff_dim=ff_dim,
            dropout=dropout,
        )
        self.backbone = TransformerBackbone(**backbone_kwargs)
        self.fc = nn.Linear(d_model, num_classes)

    def forward(self, x):
        # [B, T, D] -> mean over T -> [B, D] -> [B, num_classes]
        sequence = self.backbone(x)
        return self.fc(sequence.mean(dim=1))

# ========================
# Improved TPA with Top-k
# ========================
class ImprovedTPA(nn.Module):
    """Prototype-based temporal attention pooling with top-k masking.

    A set of learnable prototype vectors act as queries over the time steps
    of the input sequence (multi-head scaled dot-product attention). For
    every prototype only the `topk_ratio` fraction of highest-weighted time
    steps is kept and renormalised. The resulting prototype tokens are
    averaged and passed through a small MLP.

    Args:
        dim: feature dimension D; must be divisible by `heads`.
        num_prototypes: number of learnable prototypes P.
        heads: number of attention heads.
        dropout: dropout on attention weights and inside the fuse MLP.
        temperature: stored on the module for configuration purposes.
            NOTE(review): it is not used in forward(); scores are scaled by
            sqrt(head_dim). Confirm whether temperature scaling was intended.
        topk_ratio: fraction of time steps each prototype may attend to. At
            least one and at most all T steps are kept.

    Raises:
        ValueError: if `dim` is not divisible by `heads`.
    """
    def __init__(self, dim, num_prototypes=16, heads=4, dropout=0.1,
                 temperature=0.07, topk_ratio=0.25):
        super().__init__()
        if dim % heads != 0:
            raise ValueError(f"dim ({dim}) must be divisible by heads ({heads})")

        self.dim = dim
        self.heads = heads
        self.head_dim = dim // heads
        self.num_prototypes = num_prototypes
        self.temperature = temperature
        self.topk_ratio = topk_ratio

        self.proto = nn.Parameter(torch.randn(num_prototypes, dim) * 0.02)

        self.pre_norm = nn.LayerNorm(dim)

        self.q_proj = nn.Linear(dim, dim, bias=False)
        self.k_proj = nn.Linear(dim, dim, bias=False)
        self.v_proj = nn.Linear(dim, dim, bias=False)

        self.fuse = nn.Sequential(
            nn.Linear(dim, dim),
            nn.SiLU(),
            nn.Dropout(dropout),
            nn.Linear(dim, dim)
        )

        self.dropout = nn.Dropout(dropout)

    def forward(self, x):
        """
        Args:
            x: [B, T, D]
        Returns:
            [B, D] pooled representation (after the fuse MLP)
        """
        B, T, D = x.shape
        P = self.num_prototypes

        x_norm = self.pre_norm(x)

        K = self.k_proj(x_norm)
        V = self.v_proj(x_norm)
        # Prototypes are shared across the batch: project once, then expand.
        Qp = self.q_proj(self.proto).unsqueeze(0).expand(B, -1, -1)

        def split_heads(t, length):
            return t.view(B, length, self.heads, self.head_dim).transpose(1, 2)

        Qh = split_heads(Qp, P)  # [B, H, P, d]
        Kh = split_heads(K, T)   # [B, H, T, d]
        Vh = split_heads(V, T)   # [B, H, T, d]

        scores = torch.matmul(Qh, Kh.transpose(-2, -1)) / (self.head_dim ** 0.5)
        attn = F.softmax(scores, dim=-1)  # [B, H, P, T]
        attn = torch.nan_to_num(attn, nan=0.0)

        # Top-k masking: keep the k largest weights per prototype. k is
        # clamped to [1, T] because topk() raises when k exceeds the size of
        # the dimension (e.g. topk_ratio > 1).
        k = min(T, max(1, int(self.topk_ratio * T)))
        _, idx = attn.topk(k, dim=-1)  # [B, H, P, k]
        mask = torch.zeros_like(attn).scatter_(-1, idx, 1.0)
        attn = attn * mask
        # Renormalise so the kept weights sum to (almost exactly) one.
        attn = attn / (attn.sum(dim=-1, keepdim=True) + 1e-8)

        attn = self.dropout(attn)

        proto_tokens = torch.matmul(attn, Vh)  # [B, H, P, d]
        proto_tokens = proto_tokens.transpose(1, 2).contiguous().view(B, P, D)

        z_tpa = proto_tokens.mean(dim=1)  # [B, D]

        return self.fuse(z_tpa)

    def compute_diversity_loss(self):
        """Prototype diversity penalty.

        Computes the mean squared difference between the cosine-similarity
        matrix of the prototypes and the identity matrix, which pushes
        distinct prototypes towards mutual orthogonality.

        Returns:
            scalar tensor
        """
        proto_norm = F.normalize(self.proto, dim=-1)  # [P, D]
        sim = proto_norm @ proto_norm.t()  # [P, P]
        # Subtracting the identity cancels the diagonal (self-similarity = 1),
        # leaving only off-diagonal similarities to be penalised.
        div_loss = (sim - torch.eye(sim.size(0), device=sim.device)).pow(2).mean()
        return div_loss

class TPAModel(nn.Module):
    """Transformer backbone + ImprovedTPA pooling + linear classifier.

    Args:
        in_channels, d_model, num_layers, n_heads, ff_dim, dropout:
            forwarded to TransformerBackbone.
        num_classes: number of output logits.
        tpa_config: optional dict with any of the keys 'num_prototypes',
            'heads', 'dropout', 'temperature', 'topk_ratio'. Missing keys
            (or `None`) fall back to the ImprovedTPA defaults
            (16, 4, 0.1, 0.07, 0.25).
    """
    def __init__(self,
                 in_channels: int = 27,
                 d_model: int = 128,
                 num_layers: int = 2,
                 n_heads: int = 4,
                 ff_dim: int = 256,
                 dropout: float = 0.1,
                 num_classes: int = 12,
                 tpa_config=None):
        super().__init__()
        # The declared default is None; treat it as "use defaults" rather
        # than failing on subscripting.
        tpa_config = tpa_config or {}

        self.backbone = TransformerBackbone(
            in_channels=in_channels,
            d_model=d_model,
            num_layers=num_layers,
            n_heads=n_heads,
            ff_dim=ff_dim,
            dropout=dropout
        )

        self.tpa = ImprovedTPA(
            dim=d_model,
            num_prototypes=tpa_config.get('num_prototypes', 16),
            heads=tpa_config.get('heads', 4),
            dropout=tpa_config.get('dropout', 0.1),
            temperature=tpa_config.get('temperature', 0.07),
            topk_ratio=tpa_config.get('topk_ratio', 0.25)
        )

        self.classifier = nn.Linear(d_model, num_classes)

    def forward(self, x):
        features = self.backbone(x)  # [B, T, D]
        z = self.tpa(features)  # [B, D]
        logits = self.classifier(z)
        return logits

# ========================
# Improved Gated-TPA with Logit-level Fusion
# ========================
class ImprovedGatedTPAModel(nn.Module):
    """Gated combination of a mean-pooling branch and a TPA branch.

    Two fusion modes are supported:

    * logit-level (`use_logit_fusion=True`): each branch has its own linear
      classifier and a sigmoid gate with one value per class mixes the two
      logit vectors.
    * feature-level (`use_logit_fusion=False`): a sigmoid gate with one value
      per feature mixes the two pooled vectors, followed by one classifier.

    Args:
        in_channels, d_model, num_layers, n_heads, ff_dim, dropout:
            forwarded to TransformerBackbone.
        num_classes: number of output logits.
        tpa_config: optional dict with any of the keys 'num_prototypes',
            'heads', 'dropout', 'temperature', 'topk_ratio'. Missing keys
            (or `None`) fall back to the ImprovedTPA defaults
            (16, 4, 0.1, 0.07, 0.25).
        use_logit_fusion: selects the fusion mode described above.
    """
    def __init__(self,
                 in_channels: int = 27,
                 d_model: int = 128,
                 num_layers: int = 2,
                 n_heads: int = 4,
                 ff_dim: int = 256,
                 dropout: float = 0.1,
                 num_classes: int = 12,
                 tpa_config=None,
                 use_logit_fusion=True):
        super().__init__()
        # The declared default is None; treat it as "use defaults" rather
        # than failing on subscripting.
        tpa_config = tpa_config or {}

        self.use_logit_fusion = use_logit_fusion
        self.backbone = TransformerBackbone(
            in_channels=in_channels,
            d_model=d_model,
            num_layers=num_layers,
            n_heads=n_heads,
            ff_dim=ff_dim,
            dropout=dropout
        )

        self.tpa = ImprovedTPA(
            dim=d_model,
            num_prototypes=tpa_config.get('num_prototypes', 16),
            heads=tpa_config.get('heads', 4),
            dropout=tpa_config.get('dropout', 0.1),
            temperature=tpa_config.get('temperature', 0.07),
            topk_ratio=tpa_config.get('topk_ratio', 0.25)
        )

        if use_logit_fusion:
            # Logit-level fusion: one classifier per branch
            self.cls_gap = nn.Linear(d_model, num_classes)
            self.cls_tpa = nn.Linear(d_model, num_classes)

            # Gate produces one mixing weight per class
            self.gate = nn.Sequential(
                nn.Linear(d_model * 2, num_classes),
                nn.Sigmoid()
            )
        else:
            # Feature-level fusion: gate produces one weight per feature
            self.gate = nn.Sequential(
                nn.Linear(d_model * 2, d_model),
                nn.Sigmoid()
            )
            self.classifier = nn.Linear(d_model, num_classes)

    def forward(self, x):
        features = self.backbone(x)  # [B, T, D]

        # GAP branch
        z_gap = features.mean(dim=1)  # [B, D]

        # TPA branch
        z_tpa = self.tpa(features)  # [B, D]

        # Both modes gate on the concatenated branch features.
        gate_input = torch.cat([z_gap, z_tpa], dim=-1)
        g = self.gate(gate_input)  # [B, C] (logit fusion) or [B, D] (feature fusion)

        if self.use_logit_fusion:
            logits_gap = self.cls_gap(z_gap)  # [B, C]
            logits_tpa = self.cls_tpa(z_tpa)  # [B, C]
            # Per-class convex combination of the two branch logits
            logits = g * logits_gap + (1 - g) * logits_tpa
        else:
            # Per-feature convex combination, then a shared classifier
            z = g * z_gap + (1 - g) * z_tpa
            logits = self.classifier(z)

        return logits

# ========================
# Training & Evaluation
# ========================
def train_one_epoch(model, loader, opt, cfg: Config, compute_diversity=True):
    """Run one optimisation pass over `loader`.

    The objective is label-smoothed cross-entropy plus
    `cfg.diversity_weight` times the prototype diversity loss. The diversity
    term is only computed when `compute_diversity` is true and the model has
    a `tpa` attribute; otherwise it is zero. Batches whose total loss is NaN
    are skipped without an optimiser step and are excluded from the
    statistics.

    Returns:
        dict with sample-weighted averages "loss", "ce_loss", "div_loss" and
        the training accuracy "acc" (all 0 if no batch was processed).
    """
    model.train()
    seen, hits = 0, 0
    sums = {"loss": 0.0, "ce_loss": 0.0, "div_loss": 0.0}

    for inputs, targets in loader:
        inputs = inputs.to(cfg.device).float()
        targets = targets.to(cfg.device)

        opt.zero_grad(set_to_none=True)
        logits = model(inputs)

        ce_loss = F.cross_entropy(logits, targets, label_smoothing=cfg.label_smoothing)

        # Diversity term exists only for models that own a TPA module.
        if compute_diversity and hasattr(model, 'tpa'):
            div_loss = model.tpa.compute_diversity_loss()
        else:
            div_loss = torch.tensor(0.0, device=cfg.device)

        loss = ce_loss + cfg.diversity_weight * div_loss

        # Skip the update entirely for a NaN loss.
        if torch.isnan(loss):
            continue

        loss.backward()
        nn.utils.clip_grad_norm_(model.parameters(), cfg.grad_clip)
        opt.step()

        with torch.no_grad():
            batch_n = targets.size(0)
            hits += (logits.argmax(dim=-1) == targets).sum().item()
            seen += batch_n
            for key, value in (("loss", loss), ("ce_loss", ce_loss), ("div_loss", div_loss)):
                sums[key] += value.item() * batch_n

    if seen == 0:
        return {"loss": 0, "ce_loss": 0, "div_loss": 0, "acc": 0}

    return {
        "loss": sums["loss"] / seen,
        "ce_loss": sums["ce_loss"] / seen,
        "div_loss": sums["div_loss"] / seen,
        "acc": hits / seen
    }

@torch.no_grad()
def evaluate(model, loader, cfg: Config):
    """Compute accuracy and macro-F1 of `model` over `loader`.

    The model is switched to eval mode and gradients are disabled by the
    decorator. Predictions are the argmax over the last logit dimension.

    Returns:
        (accuracy, macro_f1)
    """
    model.eval()
    targets, predictions = [], []

    for inputs, labels in loader:
        inputs = inputs.to(cfg.device)
        labels = labels.to(cfg.device)
        batch_pred = model(inputs).argmax(dim=-1)
        predictions.append(batch_pred.cpu().numpy())
        targets.append(labels.cpu().numpy())

    y_true = np.concatenate(targets)
    y_pred = np.concatenate(predictions)

    return (
        accuracy_score(y_true, y_pred),
        f1_score(y_true, y_pred, average='macro'),
    )

# ========================
# Model Complexity Analysis
# ========================
def count_parameters(model):
    """Return (total, trainable) parameter element counts of `model`.

    A parameter counts as trainable when its `requires_grad` flag is set.
    """
    total = 0
    trainable = 0
    for param in model.parameters():
        size = param.numel()
        total += size
        if param.requires_grad:
            trainable += size
    return total, trainable

def estimate_flops(model, input_shape=(1, 100, 27), device='cuda'):
    """Estimate forward-pass FLOPs by hooking Conv1d and Linear modules.

    Counting rules (one multiply-add = 2 FLOPs, bias ignored):
        Conv1d: 2 * (C_in / groups) * K * C_out * L_out * batch
        Linear: 2 * in_features * out_features * rows, where `rows` is the
                number of vectors the layer is applied to, i.e. the product
                of all leading input dimensions (B * T for a [B, T, C]
                input), not just the first one.

    Only modules whose `forward` is actually invoked are counted.
    NOTE(review): attention matmuls, normalisation, activations and any
    projection applied through functional calls instead of a Linear module's
    forward (this appears to be the case inside nn.MultiheadAttention) are
    not included - treat the result as a lower bound.

    Args:
        model: module to analyse; it is left in eval mode.
        input_shape: shape of the random input fed to the model.
        device: device the random input is moved to.
    Returns:
        int, total counted FLOPs for one forward pass.
    """
    model.eval()
    total_flops = 0

    def conv1d_flops(module, inputs, output):
        nonlocal total_flops
        batch_size, out_channels, out_length = output.shape
        flops_per_element = 2 * (module.in_channels // module.groups) * module.kernel_size[0]
        total_flops += flops_per_element * out_channels * out_length * batch_size

    def linear_flops(module, inputs, output):
        nonlocal total_flops
        # A Linear layer acts on the last dimension only; every other
        # dimension multiplies the amount of work.
        rows = inputs[0].numel() // module.in_features
        total_flops += 2 * module.in_features * module.out_features * rows

    hooks = []
    try:
        for module in model.modules():
            if isinstance(module, nn.Conv1d):
                hooks.append(module.register_forward_hook(conv1d_flops))
            elif isinstance(module, nn.Linear):
                hooks.append(module.register_forward_hook(linear_flops))

        with torch.no_grad():
            x = torch.randn(input_shape).to(device)
            model(x)
    finally:
        # Always detach the hooks, even if the forward pass raised, so the
        # model is not left instrumented.
        for hook in hooks:
            hook.remove()

    return total_flops

def measure_inference_time(model, input_shape=(1, 100, 27), device='cuda', n_runs=100):
    """Measure the average forward-pass latency in milliseconds.

    Runs 10 warm-up passes, then times `n_runs` passes on one random input.
    On CUDA devices the GPU is synchronised after every pass so that the
    measured time includes kernel execution and not just launch overhead.

    Args:
        model: module to time; it is left in eval mode.
        input_shape: shape of the random input tensor.
        device: device string (e.g. 'cpu', 'cuda', 'cuda:0') or torch.device.
        n_runs: number of timed passes; must be at least 1.
    Returns:
        float, mean time per forward pass in milliseconds.
    Raises:
        ValueError: if `n_runs` is smaller than 1.
    """
    if n_runs < 1:
        raise ValueError(f"n_runs must be >= 1, got {n_runs}")

    # Compare on the device *type* so that 'cuda:0' or a torch.device object
    # also triggers synchronisation.
    on_cuda = torch.device(device).type == 'cuda'

    model.eval()
    x = torch.randn(input_shape).to(device)

    with torch.no_grad():
        # Warmup
        for _ in range(10):
            _ = model(x)

        if on_cuda:
            torch.cuda.synchronize()

        # perf_counter is monotonic and has higher resolution than time.time.
        start = time.perf_counter()
        for _ in range(n_runs):
            _ = model(x)
            if on_cuda:
                torch.cuda.synchronize()
        end = time.perf_counter()

    return (end - start) / n_runs * 1000  # seconds -> ms

def analyze_model_complexity(model, model_name, cfg: Config, input_shape=(1, 100, 27)):
    """Print and return size, FLOP and latency figures for one model.

    Model size assumes 4 bytes per parameter (float32). FLOPs and latency
    are measured on `cfg.device` with a random input of `input_shape`;
    latency is averaged over 100 runs.

    Returns:
        dict with keys 'model', 'total_params', 'trainable_params',
        'model_size_mb', 'flops', 'mflops', 'inference_time_ms'.
    """
    banner = '=' * 80
    print(f"\n{banner}")
    print(f"MODEL COMPLEXITY ANALYSIS: {model_name}")
    print(f"{banner}")

    # Parameter counts and the derived float32 footprint
    total_params, trainable_params = count_parameters(model)
    size_mb = total_params * 4 / 1024 / 1024
    print(f"Total Parameters: {total_params:,}")
    print(f"Trainable Parameters: {trainable_params:,}")
    print(f"Model Size: {size_mb:.2f} MB (float32)")

    # Hook-based FLOP estimate
    flops = estimate_flops(model, input_shape, cfg.device)
    mflops = flops / 1e6
    print(f"FLOPs: {flops:,} ({mflops:.2f} MFLOPs)")

    # Wall-clock latency
    inference_time = measure_inference_time(model, input_shape, cfg.device, n_runs=100)
    print(f"Inference Time: {inference_time:.3f} ms (avg over 100 runs)")

    report = {}
    report['model'] = model_name
    report['total_params'] = total_params
    report['trainable_params'] = trainable_params
    report['model_size_mb'] = size_mb
    report['flops'] = flops
    report['mflops'] = mflops
    report['inference_time_ms'] = inference_time
    return report

def train_model(model, train_loader, val_loader, cfg: Config, model_name: str):
    """Train `model` with AdamW and early stopping on validation accuracy.

    The diversity loss is enabled when `model_name` contains 'TPA'. An epoch
    counts as an improvement when validation accuracy beats the best value
    so far by more than `cfg.min_delta`; training stops after
    `cfg.patience` consecutive epochs without improvement. The weights of
    the best epoch are restored before returning.

    Returns:
        best validation accuracy observed.
    """
    print(f"\n[Training {model_name}]")

    # Only TPA-based models carry prototypes to regularise.
    compute_diversity = 'TPA' in model_name

    opt = torch.optim.AdamW(model.parameters(), lr=cfg.lr, weight_decay=cfg.weight_decay)

    best_acc = 0.0
    best_wts = None
    stale_epochs = 0

    for epoch in range(1, cfg.epochs + 1):
        stats = train_one_epoch(model, train_loader, opt, cfg, compute_diversity)
        val_acc, val_f1 = evaluate(model, val_loader, cfg)

        improved = val_acc > best_acc + cfg.min_delta
        if not improved:
            stale_epochs += 1
        else:
            best_acc = val_acc
            best_wts = copy.deepcopy(model.state_dict())
            stale_epochs = 0

        # Progress line every 10th epoch; TPA models report CE and Div separately.
        if epoch % 10 == 0:
            head = f"  Epoch {epoch:3d}: Train Acc={stats['acc']:.4f}, "
            if compute_diversity:
                middle = f"CE={stats['ce_loss']:.4f}, Div={stats['div_loss']:.4f}, "
            else:
                middle = f"Loss={stats['loss']:.4f}, "
            tail = f"Val Acc={val_acc:.4f}, F1={val_f1:.4f}"
            print(head + middle + tail)

        if stale_epochs >= cfg.patience:
            print(f"  Early stopping at epoch {epoch}")
            break

    if best_wts:
        model.load_state_dict(best_wts)

    print(f"  Best Val Acc: {best_acc:.4f}")
    return best_acc

def create_model(model_name: str, cfg: Config):
    """Build one of the compared models from the configuration.

    All backbone hyperparameters (`d_model`, `num_layers`, `n_heads`,
    `ff_dim`, `dropout`) are taken from `cfg`, so changing them in the
    config actually changes the model. With the default Config values this
    builds the same architectures as relying on the constructor defaults.

    Args:
        model_name: "GAP", "TPA" or "Gated-TPA".
        cfg: experiment configuration.
    Returns:
        the model, moved to `cfg.device` and cast to float32.
    Raises:
        ValueError: for an unknown `model_name`.
    """
    backbone_kwargs = {
        'd_model': cfg.d_model,
        'num_layers': cfg.num_layers,
        'n_heads': cfg.n_heads,
        'ff_dim': cfg.ff_dim,
        'dropout': cfg.dropout,
    }
    tpa_config = {
        'num_prototypes': cfg.tpa_num_prototypes,
        'heads': cfg.tpa_heads,
        'dropout': cfg.tpa_dropout,
        'temperature': cfg.tpa_temperature,
        'topk_ratio': cfg.tpa_topk_ratio
    }

    if model_name == "GAP":
        model = GAPModel(**backbone_kwargs)
    elif model_name == "TPA":
        model = TPAModel(tpa_config=tpa_config, **backbone_kwargs)
    elif model_name == "Gated-TPA":
        model = ImprovedGatedTPAModel(
            tpa_config=tpa_config,
            use_logit_fusion=cfg.use_logit_fusion,
            **backbone_kwargs
        )
    else:
        raise ValueError(f"Unknown model: {model_name}")

    return model.to(cfg.device).float()

# ========================
# Main Experiment
# ========================
def _seed_everything(seed: int) -> None:
    """Reset the Python, NumPy and PyTorch global RNGs to `seed`."""
    random.seed(seed)
    np.random.seed(seed)
    torch.manual_seed(seed)


def run_experiment(dataset_name: str, cfg: Config):
    """Train and test all three models on one dataset variant.

    Steps: load the dataset, make a stratified train/validation split of the
    training portion, optionally (only for "ORIGINAL") run the complexity
    analysis, then train each model with early stopping and evaluate it on
    the test set.

    Before each model is trained, the global RNGs and the DataLoader's
    shuffle generator are reset to SEED so that every model sees the same
    initial RNG state and the same sequence of batch orders, independent of
    which models were trained before it.

    Args:
        dataset_name: sub-directory of `cfg.data_dir` to load.
        cfg: experiment configuration.
    Returns:
        (results, complexity_results): a list with one result dict per model
        and a list of complexity dicts (empty unless dataset_name is
        "ORIGINAL").
    """
    from torch.utils.data import Subset

    print(f"\n{'='*80}")
    print(f"EXPERIMENT: {dataset_name}")
    print(f"{'='*80}")

    # Load data
    train_dataset, test_dataset = load_dataset(cfg.data_dir, dataset_name)

    # Stratified split of the training indices into train / validation
    indices = np.arange(len(train_dataset))
    y_labels = train_dataset.y.numpy()

    train_indices, val_indices = train_test_split(
        indices,
        test_size=cfg.val_split,
        random_state=SEED,
        stratify=y_labels
    )

    train_subset = Subset(train_dataset, train_indices)
    val_subset = Subset(train_dataset, val_indices)

    # Dedicated generator for shuffling; reseeded per model below.
    g = torch.Generator(device='cpu').manual_seed(SEED)
    train_loader = DataLoader(train_subset, cfg.batch_size, shuffle=True,
                              num_workers=cfg.num_workers, generator=g)
    val_loader = DataLoader(val_subset, cfg.batch_size, num_workers=cfg.num_workers)
    test_loader = DataLoader(test_dataset, cfg.batch_size, num_workers=cfg.num_workers)

    print(f"\nDataset splits:")
    print(f"  Train: {len(train_subset)}, Val: {len(val_subset)}, Test: {len(test_dataset)}")

    results = []
    complexity_results = []
    model_names = ["GAP", "TPA", "Gated-TPA"]

    # Complexity does not depend on the data, so analyse it once only.
    if dataset_name == "ORIGINAL":
        print(f"\n{'='*80}")
        print("MODEL COMPLEXITY COMPARISON")
        print(f"{'='*80}")

        for model_name in model_names:
            _seed_everything(SEED)
            model = create_model(model_name, cfg)
            complexity_results.append(analyze_model_complexity(model, model_name, cfg))

    for model_name in model_names:
        # Same starting RNG state and same shuffle sequence for every model.
        _seed_everything(SEED)
        g.manual_seed(SEED)

        # Create and train model
        model = create_model(model_name, cfg)
        best_val_acc = train_model(model, train_loader, val_loader, cfg, model_name)

        # Evaluate on test set
        test_acc, test_f1 = evaluate(model, test_loader, cfg)

        print(f"\n[{model_name} Results]")
        print(f"  Val Acc: {best_val_acc:.4f}")
        print(f"  Test Acc: {test_acc:.4f}, F1: {test_f1:.4f}")

        results.append({
            'Model': model_name,
            'Dataset': dataset_name,
            'Val_Accuracy': float(best_val_acc),
            'Test_Accuracy': float(test_acc),
            'Test_F1_Score': float(test_f1)
        })

    return results, complexity_results

# ========================
# Run All Experiments
# ========================
# Script entry point: builds the list of dataset variants, runs every
# experiment, writes a JSON report to cfg.save_dir and prints a summary.
if __name__ == "__main__":
    print("\n" + "="*80)
    print("SIMPLIFIED MODEL COMPARISON: GAP vs TPA vs Gated-TPA")
    print("="*80)
    print("\n개선사항:")
    print("  1. TPA Top-k 마스킹 적용")
    print("  2. 프로토타입 다양성 페널티")
    print("  3. 로짓 수준 MoE 융합 (별도 분류기)")
    print("="*80)

    datasets = ["ORIGINAL"]

    transitions = [
        'Standing_TO_Sitting',
        'Sitting_TO_Standing',
        'Sitting_TO_Lying',
        'Lying_TO_Sitting',
        'Standing_TO_Lying',
        'Lying_TO_Standing',
        'Standing_TO_Walking',
        'Walking_TO_Standing',
        'Walking_TO_Running',
        'Running_TO_Walking',
        'Walking_TO_Ascending_stairs',
        'Walking_TO_Descending_stairs',
        'Ascending_stairs_TO_Walking',
        'Descending_stairs_TO_Walking'
    ]

    # One variant per transition and mix percentage (10%, 20%, 30%, 40%)
    mix_pcts = [10, 20, 30, 40]
    datasets.extend(
        f"{transition}_{pct}PCT" for transition in transitions for pct in mix_pcts
    )

    print(f"\nTotal datasets to test: {len(datasets)}")
    # Number of transition variants only; "ORIGINAL" is counted in the total above.
    print(f"  - transitions: {len(transitions) * len(mix_pcts)}")

    all_results = []
    all_complexity = []

    # Run experiments
    for i, dataset_name in enumerate(datasets, 1):
        print(f"\n[Progress: {i}/{len(datasets)}]")

        results, complexity = run_experiment(dataset_name, cfg)
        all_results.extend(results)
        if complexity:  # only returned for the first ("ORIGINAL") dataset
            all_complexity = complexity

    # Save all results
    print(f"\n{'='*80}")
    print("SAVING RESULTS")
    print(f"{'='*80}")

    results_dict = {
        'experiment_info': {
            'date': time.strftime('%Y-%m-%d %H:%M:%S'),
            'version': 'simplified_v1',
            'improvements': [
                'TPA Top-k masking',
                'Prototype diversity penalty',
                'Logit-level MoE fusion'
            ],
            'models': ['GAP', 'TPA', 'Gated-TPA'],
            'total_datasets': len(datasets),
            'datasets': datasets,
            'config': {
                'epochs': cfg.epochs,
                'batch_size': cfg.batch_size,
                'lr': cfg.lr,
                'd_model': cfg.d_model,
                'tpa_num_prototypes': cfg.tpa_num_prototypes,
                'tpa_heads': cfg.tpa_heads,
                'tpa_temperature': cfg.tpa_temperature,
                'tpa_topk_ratio': cfg.tpa_topk_ratio,
                'diversity_weight': cfg.diversity_weight,
                'use_logit_fusion': cfg.use_logit_fusion
            }
        },
        'model_complexity': all_complexity,
        'results': all_results
    }

    # Save to JSON. Create the target folder first so a missing directory
    # cannot discard the results of the whole (long) run.
    os.makedirs(cfg.save_dir, exist_ok=True)
    json_path = os.path.join(cfg.save_dir, "pamap2_tpa_transition_cnn_simplified.json")
    with open(json_path, 'w', encoding='utf-8') as f:
        json.dump(results_dict, f, indent=2)

    print(f"\nResults saved to: {json_path}")

    # Print summary
    print(f"\n{'='*80}")
    print("SUMMARY")
    print(f"{'='*80}")
    print(f"Total experiments: {len(all_results)}")
    print(f"Total datasets tested: {len(datasets)}")
    print("Models compared: 3 (GAP, TPA, Gated-TPA)")

    # Calculate average performance per model
    print(f"\n{'='*80}")
    print("AVERAGE PERFORMANCE (All Datasets)")
    print(f"{'='*80}")

    for model_name in ['GAP', 'TPA', 'Gated-TPA']:
        model_results = [r for r in all_results if r['Model'] == model_name]
        if not model_results:
            # Nothing to average for this model; avoid a NaN mean.
            continue
        avg_acc = np.mean([r['Test_Accuracy'] for r in model_results])
        avg_f1 = np.mean([r['Test_F1_Score'] for r in model_results])
        print(f"{model_name:12s}: Acc={avg_acc:.4f}, F1={avg_f1:.4f}")

    # Print model complexity table
    if all_complexity:
        print(f"\n{'='*80}")
        print("MODEL COMPLEXITY COMPARISON")
        print(f"{'='*80}")
        print(f"{'Model':<12} {'Params':<12} {'Size(MB)':<10} {'MFLOPs':<10} {'Time(ms)':<10}")
        print("-" * 80)
        for comp in all_complexity:
            print(f"{comp['model']:<12} {comp['total_params']:<12,} "
                  f"{comp['model_size_mb']:<10.2f} {comp['mflops']:<10.2f} "
                  f"{comp['inference_time_ms']:<10.3f}")

    print(f"\n{'='*80}")
    print("EXPERIMENT COMPLETE")
    print(f"{'='*80}")


Drive already mounted at /content/drive; to attempt to forcibly remount, call drive.mount("/content/drive", force_remount=True).

SIMPLIFIED MODEL COMPARISON: GAP vs TPA vs Gated-TPA

개선사항:
  1. TPA Top-k 마스킹 적용
  2. 프로토타입 다양성 페널티
  3. 로짓 수준 MoE 융합 (별도 분류기)

Total datasets to test: 57
  - transitions: 58

[Progress: 1/57]

EXPERIMENT: ORIGINAL

Loading ORIGINAL...
  Path: /content/drive/MyDrive/AI_data/TPA2/pamap2_transition_datasets/ORIGINAL
  Train: (31084, 100, 27), Test: (7772, 100, 27)

Dataset splits:
  Train: 24867, Val: 6217, Test: 7772

MODEL COMPLEXITY COMPARISON

MODEL COMPLEXITY ANALYSIS: GAP
Total Parameters: 270,348
Trainable Parameters: 270,348
Model Size: 1.03 MB (float32)
FLOPs: 272,128 (0.27 MFLOPs)
Inference Time: 0.885 ms (avg over 100 runs)

MODEL COMPLEXITY ANALYSIS: TPA
Total Parameters: 354,828
Trainable Parameters: 354,828
Model Size: 1.35 MB (float32)




FLOPs: 927,488 (0.93 MFLOPs)
Inference Time: 2.187 ms (avg over 100 runs)

MODEL COMPLEXITY ANALYSIS: Gated-TPA
Total Parameters: 359,460
Trainable Parameters: 359,460
Model Size: 1.37 MB (float32)
FLOPs: 936,704 (0.94 MFLOPs)
Inference Time: 2.093 ms (avg over 100 runs)

[Training GAP]
  Epoch  10: Train Acc=0.9219, Loss=0.5723, Val Acc=0.9213, F1=0.9125
  Epoch  20: Train Acc=0.9579, Loss=0.4494, Val Acc=0.9472, F1=0.9408
  Epoch  30: Train Acc=0.9745, Loss=0.4010, Val Acc=0.9590, F1=0.9540
  Epoch  40: Train Acc=0.9848, Loss=0.3714, Val Acc=0.9633, F1=0.9596
  Epoch  50: Train Acc=0.9901, Loss=0.3530, Val Acc=0.9691, F1=0.9657
  Epoch  60: Train Acc=0.9926, Loss=0.3414, Val Acc=0.9743, F1=0.9717
  Epoch  70: Train Acc=0.9945, Loss=0.3324, Val Acc=0.9768, F1=0.9741
  Epoch  80: Train Acc=0.9955, Loss=0.3263, Val Acc=0.9796, F1=0.9774
  Epoch  90: Train Acc=0.9962, Loss=0.3222, Val Acc=0.9802, F1=0.9774
  Epoch 100: Train Acc=0.9967, Loss=0.3186, Val Acc=0.9807, F1=0.9787
  Best Val A



  Epoch  10: Train Acc=0.9254, CE=0.5463, Div=0.0684, Val Acc=0.9178, F1=0.9069
  Epoch  20: Train Acc=0.9580, CE=0.4487, Div=0.0168, Val Acc=0.9447, F1=0.9378
  Epoch  30: Train Acc=0.9719, CE=0.4072, Div=0.0081, Val Acc=0.9545, F1=0.9481
  Epoch  40: Train Acc=0.9828, CE=0.3793, Div=0.0053, Val Acc=0.9664, F1=0.9623
  Epoch  50: Train Acc=0.9888, CE=0.3609, Div=0.0027, Val Acc=0.9699, F1=0.9672
  Epoch  60: Train Acc=0.9923, CE=0.3479, Div=0.0018, Val Acc=0.9736, F1=0.9707
  Epoch  70: Train Acc=0.9939, CE=0.3395, Div=0.0012, Val Acc=0.9749, F1=0.9731
  Epoch  80: Train Acc=0.9953, CE=0.3333, Div=0.0005, Val Acc=0.9751, F1=0.9728
  Epoch  90: Train Acc=0.9960, CE=0.3296, Div=0.0003, Val Acc=0.9788, F1=0.9779
  Epoch 100: Train Acc=0.9961, CE=0.3256, Div=0.0002, Val Acc=0.9762, F1=0.9747
  Best Val Acc: 0.9797

[TPA Results]
  Val Acc: 0.9797
  Test Acc: 0.9811, F1: 0.9796

[Training Gated-TPA]




  Epoch  10: Train Acc=0.9241, CE=0.5405, Div=0.0434, Val Acc=0.9207, F1=0.9122
  Epoch  20: Train Acc=0.9550, CE=0.4414, Div=0.0083, Val Acc=0.9474, F1=0.9406
  Epoch  30: Train Acc=0.9745, CE=0.3936, Div=0.0029, Val Acc=0.9543, F1=0.9490
  Epoch  40: Train Acc=0.9847, CE=0.3653, Div=0.0017, Val Acc=0.9665, F1=0.9627
  Epoch  50: Train Acc=0.9902, CE=0.3482, Div=0.0007, Val Acc=0.9693, F1=0.9661
  Epoch  60: Train Acc=0.9934, CE=0.3362, Div=0.0008, Val Acc=0.9715, F1=0.9685
  Epoch  70: Train Acc=0.9943, CE=0.3291, Div=0.0003, Val Acc=0.9722, F1=0.9694
  Epoch  80: Train Acc=0.9950, CE=0.3233, Div=0.0002, Val Acc=0.9746, F1=0.9727
  Epoch  90: Train Acc=0.9961, CE=0.3185, Div=0.0001, Val Acc=0.9727, F1=0.9697
  Epoch 100: Train Acc=0.9965, CE=0.3156, Div=0.0001, Val Acc=0.9762, F1=0.9739
  Best Val Acc: 0.9775

[Gated-TPA Results]
  Val Acc: 0.9775
  Test Acc: 0.9789, F1: 0.9767

[Progress: 2/57]

EXPERIMENT: Standing_TO_Sitting_10PCT

Loading Standing_TO_Sitting_10PCT...
  Path: /con



  Epoch  10: Train Acc=0.9295, Loss=0.5523, Val Acc=0.9275, F1=0.9184
  Epoch  20: Train Acc=0.9588, Loss=0.4446, Val Acc=0.9466, F1=0.9413
  Epoch  30: Train Acc=0.9763, Loss=0.3961, Val Acc=0.9639, F1=0.9592
  Epoch  40: Train Acc=0.9853, Loss=0.3672, Val Acc=0.9689, F1=0.9643
  Epoch  50: Train Acc=0.9901, Loss=0.3500, Val Acc=0.9734, F1=0.9697
  Epoch  60: Train Acc=0.9929, Loss=0.3386, Val Acc=0.9773, F1=0.9759
  Epoch  70: Train Acc=0.9948, Loss=0.3304, Val Acc=0.9784, F1=0.9762
  Epoch  80: Train Acc=0.9948, Loss=0.3252, Val Acc=0.9804, F1=0.9778
  Epoch  90: Train Acc=0.9965, Loss=0.3200, Val Acc=0.9827, F1=0.9809
  Epoch 100: Train Acc=0.9965, Loss=0.3176, Val Acc=0.9838, F1=0.9820
  Best Val Acc: 0.9838

[GAP Results]
  Val Acc: 0.9838
  Test Acc: 0.9820, F1: 0.9805

[Training TPA]




  Epoch  10: Train Acc=0.9292, CE=0.5433, Div=0.0410, Val Acc=0.9251, F1=0.9146
  Epoch  20: Train Acc=0.9592, CE=0.4437, Div=0.0070, Val Acc=0.9469, F1=0.9397
  Epoch  30: Train Acc=0.9748, CE=0.4000, Div=0.0031, Val Acc=0.9574, F1=0.9511
  Epoch  40: Train Acc=0.9834, CE=0.3736, Div=0.0014, Val Acc=0.9656, F1=0.9610
  Epoch  50: Train Acc=0.9886, CE=0.3574, Div=0.0007, Val Acc=0.9741, F1=0.9711
  Epoch  60: Train Acc=0.9918, CE=0.3458, Div=0.0006, Val Acc=0.9749, F1=0.9719
  Epoch  70: Train Acc=0.9939, CE=0.3381, Div=0.0006, Val Acc=0.9789, F1=0.9761
  Epoch  80: Train Acc=0.9952, CE=0.3321, Div=0.0004, Val Acc=0.9788, F1=0.9763
  Epoch  90: Train Acc=0.9958, CE=0.3276, Div=0.0002, Val Acc=0.9798, F1=0.9773
  Epoch 100: Train Acc=0.9966, CE=0.3246, Div=0.0002, Val Acc=0.9825, F1=0.9800
  Best Val Acc: 0.9827

[TPA Results]
  Val Acc: 0.9827
  Test Acc: 0.9807, F1: 0.9788

[Training Gated-TPA]




  Epoch  10: Train Acc=0.9299, CE=0.5241, Div=0.0283, Val Acc=0.9300, F1=0.9214
  Epoch  20: Train Acc=0.9618, CE=0.4275, Div=0.0060, Val Acc=0.9498, F1=0.9445
  Epoch  30: Train Acc=0.9778, CE=0.3866, Div=0.0020, Val Acc=0.9637, F1=0.9600
  Epoch  40: Train Acc=0.9859, CE=0.3595, Div=0.0010, Val Acc=0.9703, F1=0.9671
  Epoch  50: Train Acc=0.9910, CE=0.3439, Div=0.0008, Val Acc=0.9749, F1=0.9724
  Epoch  60: Train Acc=0.9932, CE=0.3328, Div=0.0008, Val Acc=0.9765, F1=0.9744
  Epoch  70: Train Acc=0.9945, CE=0.3255, Div=0.0003, Val Acc=0.9787, F1=0.9772
  Epoch  80: Train Acc=0.9954, CE=0.3196, Div=0.0000, Val Acc=0.9795, F1=0.9784
  Epoch  90: Train Acc=0.9961, CE=0.3161, Div=0.0001, Val Acc=0.9806, F1=0.9793
  Epoch 100: Train Acc=0.9966, CE=0.3132, Div=0.0000, Val Acc=0.9797, F1=0.9784
  Best Val Acc: 0.9816

[Gated-TPA Results]
  Val Acc: 0.9816
  Test Acc: 0.9823, F1: 0.9808

[Progress: 3/57]

EXPERIMENT: Standing_TO_Sitting_20PCT

Loading Standing_TO_Sitting_20PCT...
  Path: /con



  Epoch  10: Train Acc=0.9250, Loss=0.5563, Val Acc=0.9226, F1=0.9150
  Epoch  20: Train Acc=0.9578, Loss=0.4456, Val Acc=0.9456, F1=0.9393
  Epoch  30: Train Acc=0.9755, Loss=0.3943, Val Acc=0.9592, F1=0.9550
  Epoch  40: Train Acc=0.9862, Loss=0.3662, Val Acc=0.9696, F1=0.9666
  Epoch  50: Train Acc=0.9912, Loss=0.3481, Val Acc=0.9735, F1=0.9712
  Epoch  60: Train Acc=0.9940, Loss=0.3366, Val Acc=0.9763, F1=0.9743
  Epoch  70: Train Acc=0.9947, Loss=0.3300, Val Acc=0.9770, F1=0.9751
  Epoch  80: Train Acc=0.9966, Loss=0.3232, Val Acc=0.9772, F1=0.9757
  Epoch  90: Train Acc=0.9968, Loss=0.3196, Val Acc=0.9785, F1=0.9767
  Epoch 100: Train Acc=0.9971, Loss=0.3162, Val Acc=0.9792, F1=0.9776
  Best Val Acc: 0.9806

[GAP Results]
  Val Acc: 0.9806
  Test Acc: 0.9786, F1: 0.9775

[Training TPA]




  Epoch  10: Train Acc=0.9274, CE=0.5402, Div=0.0348, Val Acc=0.9266, F1=0.9177
  Epoch  20: Train Acc=0.9596, CE=0.4429, Div=0.0081, Val Acc=0.9477, F1=0.9417
  Epoch  30: Train Acc=0.9752, CE=0.3984, Div=0.0046, Val Acc=0.9633, F1=0.9593
  Epoch  40: Train Acc=0.9849, CE=0.3721, Div=0.0022, Val Acc=0.9683, F1=0.9654
  Epoch  50: Train Acc=0.9906, CE=0.3551, Div=0.0015, Val Acc=0.9741, F1=0.9721
  Epoch  60: Train Acc=0.9933, CE=0.3444, Div=0.0009, Val Acc=0.9727, F1=0.9711
  Epoch  70: Train Acc=0.9948, CE=0.3372, Div=0.0005, Val Acc=0.9779, F1=0.9772
  Epoch  80: Train Acc=0.9952, CE=0.3315, Div=0.0005, Val Acc=0.9763, F1=0.9759
  Epoch  90: Train Acc=0.9968, CE=0.3262, Div=0.0003, Val Acc=0.9800, F1=0.9794
  Epoch 100: Train Acc=0.9973, CE=0.3228, Div=0.0001, Val Acc=0.9797, F1=0.9785
  Best Val Acc: 0.9810

[TPA Results]
  Val Acc: 0.9810
  Test Acc: 0.9791, F1: 0.9770

[Training Gated-TPA]




  Epoch  10: Train Acc=0.9302, CE=0.5232, Div=0.0318, Val Acc=0.9254, F1=0.9178
  Epoch  20: Train Acc=0.9611, CE=0.4284, Div=0.0069, Val Acc=0.9512, F1=0.9445
  Epoch  30: Train Acc=0.9778, CE=0.3839, Div=0.0023, Val Acc=0.9640, F1=0.9593
  Epoch  40: Train Acc=0.9867, CE=0.3586, Div=0.0016, Val Acc=0.9706, F1=0.9676
  Epoch  50: Train Acc=0.9919, CE=0.3412, Div=0.0016, Val Acc=0.9746, F1=0.9722
  Epoch  60: Train Acc=0.9935, CE=0.3306, Div=0.0012, Val Acc=0.9757, F1=0.9742
  Epoch  70: Train Acc=0.9951, CE=0.3241, Div=0.0004, Val Acc=0.9763, F1=0.9751
  Epoch  80: Train Acc=0.9965, CE=0.3185, Div=0.0003, Val Acc=0.9789, F1=0.9774
  Epoch  90: Train Acc=0.9969, CE=0.3152, Div=0.0001, Val Acc=0.9785, F1=0.9777
  Epoch 100: Train Acc=0.9970, CE=0.3130, Div=0.0001, Val Acc=0.9795, F1=0.9781
  Best Val Acc: 0.9808

[Gated-TPA Results]
  Val Acc: 0.9808
  Test Acc: 0.9777, F1: 0.9763

[Progress: 4/57]

EXPERIMENT: Standing_TO_Sitting_30PCT

Loading Standing_TO_Sitting_30PCT...
  Path: /con



  Epoch  10: Train Acc=0.9258, Loss=0.5586, Val Acc=0.9256, F1=0.9175
  Epoch  20: Train Acc=0.9589, Loss=0.4457, Val Acc=0.9510, F1=0.9442
  Epoch  30: Train Acc=0.9761, Loss=0.3965, Val Acc=0.9629, F1=0.9576
  Epoch  40: Train Acc=0.9855, Loss=0.3676, Val Acc=0.9703, F1=0.9672
  Epoch  50: Train Acc=0.9903, Loss=0.3506, Val Acc=0.9744, F1=0.9716
  Epoch  60: Train Acc=0.9934, Loss=0.3380, Val Acc=0.9779, F1=0.9758
  Epoch  70: Train Acc=0.9953, Loss=0.3297, Val Acc=0.9806, F1=0.9793
  Epoch  80: Train Acc=0.9961, Loss=0.3242, Val Acc=0.9810, F1=0.9792
  Epoch  90: Train Acc=0.9969, Loss=0.3195, Val Acc=0.9800, F1=0.9782
  Epoch 100: Train Acc=0.9971, Loss=0.3161, Val Acc=0.9801, F1=0.9782
  Early stopping at epoch 100
  Best Val Acc: 0.9810

[GAP Results]
  Val Acc: 0.9810
  Test Acc: 0.9778, F1: 0.9741

[Training TPA]




  Epoch  10: Train Acc=0.9274, CE=0.5411, Div=0.0372, Val Acc=0.9259, F1=0.9185
  Epoch  20: Train Acc=0.9594, CE=0.4427, Div=0.0080, Val Acc=0.9460, F1=0.9404
  Epoch  30: Train Acc=0.9746, CE=0.4007, Div=0.0034, Val Acc=0.9618, F1=0.9577
  Epoch  40: Train Acc=0.9840, CE=0.3741, Div=0.0022, Val Acc=0.9687, F1=0.9663
  Epoch  50: Train Acc=0.9896, CE=0.3565, Div=0.0013, Val Acc=0.9699, F1=0.9669
  Epoch  60: Train Acc=0.9928, CE=0.3441, Div=0.0010, Val Acc=0.9760, F1=0.9734
  Epoch  70: Train Acc=0.9946, CE=0.3371, Div=0.0010, Val Acc=0.9747, F1=0.9727
  Epoch  80: Train Acc=0.9950, CE=0.3320, Div=0.0006, Val Acc=0.9791, F1=0.9770
  Epoch  90: Train Acc=0.9962, CE=0.3268, Div=0.0002, Val Acc=0.9792, F1=0.9770
  Epoch 100: Train Acc=0.9967, CE=0.3238, Div=0.0002, Val Acc=0.9798, F1=0.9782
  Best Val Acc: 0.9807

[TPA Results]
  Val Acc: 0.9807
  Test Acc: 0.9784, F1: 0.9755

[Training Gated-TPA]




  Epoch  10: Train Acc=0.9294, CE=0.5283, Div=0.0345, Val Acc=0.9244, F1=0.9187
  Epoch  20: Train Acc=0.9608, CE=0.4309, Div=0.0080, Val Acc=0.9425, F1=0.9372
  Epoch  30: Train Acc=0.9765, CE=0.3891, Div=0.0030, Val Acc=0.9613, F1=0.9571
  Epoch  40: Train Acc=0.9863, CE=0.3614, Div=0.0019, Val Acc=0.9718, F1=0.9697
  Epoch  50: Train Acc=0.9905, CE=0.3441, Div=0.0011, Val Acc=0.9754, F1=0.9731
  Epoch  60: Train Acc=0.9935, CE=0.3319, Div=0.0010, Val Acc=0.9782, F1=0.9767
  Epoch  70: Train Acc=0.9951, CE=0.3244, Div=0.0004, Val Acc=0.9791, F1=0.9775
  Epoch  80: Train Acc=0.9961, CE=0.3195, Div=0.0002, Val Acc=0.9803, F1=0.9786
  Epoch  90: Train Acc=0.9969, CE=0.3151, Div=0.0002, Val Acc=0.9820, F1=0.9805
  Epoch 100: Train Acc=0.9972, CE=0.3125, Div=0.0001, Val Acc=0.9820, F1=0.9808
  Best Val Acc: 0.9820

[Gated-TPA Results]
  Val Acc: 0.9820
  Test Acc: 0.9792, F1: 0.9759

[Progress: 5/57]

EXPERIMENT: Standing_TO_Sitting_40PCT

Loading Standing_TO_Sitting_40PCT...
  Path: /con



  Epoch  10: Train Acc=0.9251, Loss=0.5621, Val Acc=0.9210, F1=0.9117
