In [None]:
# 결측치에 강인한 HAR

In [6]:
import numpy as np
import torch
import torch.nn as nn
import torch.nn.functional as F
from torch.utils.data import Dataset, DataLoader
from pathlib import Path
import json
from collections import defaultdict

# ============================================================================
# 1. UCI-HAR Raw 데이터 로더
# ============================================================================
class UCIHARRawDataset(Dataset):
    """
    UCI-HAR Raw 데이터셋
    - 128 timesteps × 9 channels (3-axis accelerometer + 3-axis gyroscope)
    """
    def __init__(self, data_dir, split='train'):
        self.data_dir = Path(data_dir)
        self.split = split

        # Load data
        if split == 'train':
            self.data = self._load_split('train')
        else:
            self.data = self._load_split('test')

        self.signals, self.labels = self.data

    def _load_split(self, split):
        """Load UCI-HAR raw inertial signals"""
        signals_dir = self.data_dir / split / 'Inertial Signals'

        # Signal files
        signal_types = [
            'body_acc_x', 'body_acc_y', 'body_acc_z',
            'body_gyro_x', 'body_gyro_y', 'body_gyro_z',
            'total_acc_x', 'total_acc_y', 'total_acc_z'
        ]

        signals = []
        for sig_type in signal_types:
            filename = signals_dir / f'{sig_type}_{split}.txt'
            data = np.loadtxt(filename)
            signals.append(data)

        # Stack: [N, 128, 9]
        signals = np.stack(signals, axis=-1)

        # Load labels
        labels_file = self.data_dir / split / f'y_{split}.txt'
        labels = np.loadtxt(labels_file, dtype=np.int32) - 1  # 0-indexed

        return signals, labels

    def __len__(self):
        return len(self.labels)

    def __getitem__(self, idx):
        x = torch.FloatTensor(self.signals[idx])  # [128, 9]
        y = torch.LongTensor([self.labels[idx]])[0]
        return x, y


# ============================================================================
# 2. Augmentation Strategies (데이터/증강 관점 베이스라인)
# ============================================================================
class TemporalMaskingBatch:
    """연속 구간 마스킹 (Ours)"""
    def __init__(self, mask_ratio=0.3, mask_length_range=(10, 30)):
        self.mask_ratio = mask_ratio
        self.mask_length_range = mask_length_range

    def __call__(self, x):
        """
        Args:
            x: [B, T, C] tensor
        Returns:
            masked_x: [B, T, C], mask: [B, T]
        """
        B, T, C = x.shape
        mask = torch.ones(B, T, dtype=torch.bool, device=x.device)

        max_len = min(self.mask_length_range[1], T)
        min_len = min(self.mask_length_range[0], T)

        for b in range(B):
            num_to_mask = int(T * self.mask_ratio)

            while num_to_mask > 0:
                length = np.random.randint(min_len, max_len + 1)
                length = min(length, num_to_mask, T)
                start = np.random.randint(0, T - length + 1)
                mask[b, start:start + length] = False
                num_to_mask -= length

        masked_x = x.clone()
        masked_x[~mask] = 0.0

        return masked_x, mask


class RandomPointDrop:
    """랜덤 포인트 드롭 (비연속 결측)"""
    def __init__(self, drop_ratio=0.3):
        self.drop_ratio = drop_ratio

    def __call__(self, x):
        """
        Args:
            x: [B, T, C] tensor
        Returns:
            dropped_x: [B, T, C], mask: [B, T]
        """
        B, T, C = x.shape
        mask = torch.rand(B, T, device=x.device) > self.drop_ratio

        dropped_x = x.clone()
        dropped_x[~mask] = 0.0

        return dropped_x, mask


class ChannelDrop:
    """채널/센서 드롭 (센서 고장 시뮬레이션)"""
    def __init__(self, drop_prob=0.3, mode='random'):
        """
        Args:
            drop_prob: probability to drop channels
            mode: 'random' (individual channels), 'axis' (3-channel groups),
                  'sensor' (acc/gyro groups)
        """
        self.drop_prob = drop_prob
        self.mode = mode

    def __call__(self, x):
        """
        Args:
            x: [B, T, C] tensor (C=9: acc_xyz, gyro_xyz, total_acc_xyz)
        Returns:
            dropped_x: [B, T, C], mask: [B, C] (channel-wise mask)
        """
        B, T, C = x.shape

        if self.mode == 'random':
            # Individual channel drop
            channel_mask = torch.rand(B, C, device=x.device) > self.drop_prob

        elif self.mode == 'axis':
            # Drop entire axis (3 channels at once: x, y, or z)
            # Channels: [acc_x, acc_y, acc_z, gyro_x, gyro_y, gyro_z, total_x, total_y, total_z]
            # Groups: [0,3,6], [1,4,7], [2,5,8]
            axis_groups = [[0,3,6], [1,4,7], [2,5,8]]
            channel_mask = torch.ones(B, C, dtype=torch.bool, device=x.device)
            for b in range(B):
                for group in axis_groups:
                    if np.random.rand() < self.drop_prob:
                        channel_mask[b, group] = False

        elif self.mode == 'sensor':
            # Drop entire sensor (acc: 0-2, gyro: 3-5, total_acc: 6-8)
            sensor_groups = [[0,1,2], [3,4,5], [6,7,8]]
            channel_mask = torch.ones(B, C, dtype=torch.bool, device=x.device)
            for b in range(B):
                for group in sensor_groups:
                    if np.random.rand() < self.drop_prob:
                        channel_mask[b, group] = False

        dropped_x = x.clone()
        # Expand mask: [B, C] → [B, T, C]
        dropped_x *= channel_mask.unsqueeze(1).float()

        # For compatibility, return temporal mask (all True if any channel observed)
        temporal_mask = channel_mask.any(dim=1, keepdim=True).expand(B, T)

        return dropped_x, temporal_mask


class StandardAugmentation:
    """표준 SSL 증강 (결측 없음 - SimCLR/BYOL 스타일)"""
    def __init__(self, noise_std=0.05, scale_range=(0.8, 1.2)):
        self.noise_std = noise_std
        self.scale_range = scale_range

    def __call__(self, x):
        """
        Args:
            x: [B, T, C] tensor
        Returns:
            augmented_x: [B, T, C], mask: [B, T] (all True)
        """
        B, T, C = x.shape

        # Gaussian noise
        noise = torch.randn_like(x) * self.noise_std
        augmented_x = x + noise

        # Random scaling
        scale = torch.empty(B, 1, C, device=x.device).uniform_(*self.scale_range)
        augmented_x = augmented_x * scale

        # No masking (all observed)
        mask = torch.ones(B, T, dtype=torch.bool, device=x.device)

        return augmented_x, mask


# ============================================================================
# 3. ELK Backbone
# ============================================================================
class ELKBlock(nn.Module):
    """Efficient Large Kernel Block with structural reparameterization."""
    def __init__(self, in_channels, out_channels, kernel_size=31, deploy=False):
        super().__init__()
        self.deploy = deploy
        self.in_channels = in_channels
        self.out_channels = out_channels
        self.kernel_size = kernel_size

        padding_large1 = kernel_size // 2
        kernel_size_large2 = kernel_size - 2
        padding_large2 = kernel_size_large2 // 2
        kernel_size_small1 = 5
        padding_small1 = kernel_size_small1 // 2
        kernel_size_small2 = 3
        padding_small2 = kernel_size_small2 // 2

        if deploy:
            self.reparam_conv = nn.Conv1d(
                in_channels, in_channels, kernel_size,
                padding=padding_large1, groups=in_channels, bias=True
            )
        else:
            self.dw_large1 = nn.Conv1d(
                in_channels, in_channels, kernel_size,
                padding=padding_large1, groups=in_channels, bias=False
            )
            self.bn_large1 = nn.BatchNorm1d(in_channels)
            self.dw_large2 = nn.Conv1d(
                in_channels, in_channels, kernel_size_large2,
                padding=padding_large2, groups=in_channels, bias=False
            )
            self.bn_large2 = nn.BatchNorm1d(in_channels)
            self.dw_small1 = nn.Conv1d(
                in_channels, in_channels, kernel_size_small1,
                padding=padding_small1, groups=in_channels, bias=False
            )
            self.bn_small1 = nn.BatchNorm1d(in_channels)
            self.dw_small2 = nn.Conv1d(
                in_channels, in_channels, kernel_size_small2,
                padding=padding_small2, groups=in_channels, bias=False
            )
            self.bn_small2 = nn.BatchNorm1d(in_channels)
            self.bn_id = nn.BatchNorm1d(in_channels)

        self.pointwise = nn.Sequential(
            nn.Conv1d(in_channels, out_channels, kernel_size=1, bias=False),
            nn.BatchNorm1d(out_channels),
        )
        self.activation = nn.GELU()

    def forward(self, x):
        if self.deploy:
            x = self.reparam_conv(x)
        else:
            x1 = self.bn_large1(self.dw_large1(x))
            x2 = self.bn_large2(self.dw_large2(x))
            x3 = self.bn_small1(self.dw_small1(x))
            x4 = self.bn_small2(self.dw_small2(x))
            x5 = self.bn_id(x)
            x = x1 + x2 + x3 + x4 + x5
        x = self.activation(x)
        return self.pointwise(x)


class ELKBackbone(nn.Module):
    """ELK Backbone built by stacking ELKBlocks"""
    def __init__(self, in_channels=9, d_model=128, num_layers=6, kernel_size=31, dropout=0.1):
        super().__init__()
        self.stem = nn.Sequential(
            nn.Conv1d(in_channels, d_model, kernel_size=3, padding=1, bias=False),
            nn.BatchNorm1d(d_model),
            nn.GELU(),
        )
        layers = []
        for _ in range(num_layers):
            layers.append(ELKBlock(d_model, d_model, kernel_size=kernel_size))
            layers.append(nn.Dropout(dropout))
        self.elk_layers = nn.Sequential(*layers)
        self.out_channels = d_model

    def forward(self, x):
        x = self.stem(x)
        x = self.elk_layers(x)
        return x


# ============================================================================
# 4. Encoder with Masked Pooling
# ============================================================================
def masked_global_avg_pool(h, mask):
    """Mask-aware global average pooling"""
    mask_float = mask.unsqueeze(1).float()  # [B, 1, T]
    numerator = (h * mask_float).sum(dim=-1)
    denominator = mask_float.sum(dim=-1).clamp_min(1e-6)
    return numerator / denominator


class ELKEncoder(nn.Module):
    """ELK-based encoder with mask-aware pooling"""
    def __init__(self, in_channels=9, d_model=128, num_layers=6,
                 kernel_size=31, output_dim=256, dropout=0.1):
        super().__init__()
        self.backbone = ELKBackbone(in_channels, d_model, num_layers, kernel_size, dropout)
        self.projection = nn.Sequential(
            nn.Linear(d_model, output_dim),
            nn.BatchNorm1d(output_dim),
            nn.GELU(),
            nn.Dropout(dropout),
            nn.Linear(output_dim, output_dim)
        )

    def forward(self, x, mask=None):
        x = x.transpose(1, 2)  # [B, T, C] → [B, C, T]
        h = self.backbone(x)
        if mask is not None:
            h = masked_global_avg_pool(h, mask)
        else:
            h = h.mean(dim=-1)
        z = self.projection(h)
        z = F.normalize(z, dim=1)
        return z


# ============================================================================
# 5. Loss Functions
# ============================================================================
class MTCLoss(nn.Module):
    """Masked Temporal Consistency Loss"""
    def forward(self, z_clean, z_masked):
        return F.mse_loss(z_clean, z_masked)


class SymmetricNTXentLoss(nn.Module):
    """Symmetric NT-Xent Loss"""
    def __init__(self, temperature=0.07):
        super().__init__()
        self.temperature = temperature

    def forward(self, z1, z2):
        B = z1.shape[0]
        device = z1.device

        z1 = F.normalize(z1, dim=1)
        z2 = F.normalize(z2, dim=1)

        sim_11 = (z1 @ z1.T) / self.temperature
        sim_22 = (z2 @ z2.T) / self.temperature
        sim_12 = (z1 @ z2.T) / self.temperature
        sim_21 = sim_12.T

        mask = torch.eye(B, device=device, dtype=torch.bool)
        sim_11 = sim_11.masked_fill(mask, -9e15)
        sim_22 = sim_22.masked_fill(mask, -9e15)

        logits_1 = torch.cat([sim_12, sim_11], dim=1)
        logits_2 = torch.cat([sim_21, sim_22], dim=1)
        labels = torch.arange(B, device=device)

        loss_1 = F.cross_entropy(logits_1, labels)
        loss_2 = F.cross_entropy(logits_2, labels)

        return 0.5 * (loss_1 + loss_2)


# ============================================================================
# 6. 통합 프레임워크 (모든 베이스라인 지원)
# ============================================================================
class UnifiedSSLFramework(nn.Module):
    """
    Unified SSL Framework supporting multiple baselines:
    1. SL-Only (Supervised Learning)
    2. SSL w/o Missing (Standard Contrastive)
    3. Random Point Drop
    4. Channel Drop
    5. MTC + MICL (Ours)
    """
    def __init__(self, method='mtc_micl', in_channels=9, d_model=128,
                 num_layers=6, kernel_size=31, output_dim=256, dropout=0.1,
                 mask_ratio=0.3, temperature=0.07,
                 lambda_mtc=1.0, lambda_micl=1.0, channel_drop_mode='random'):
        super().__init__()

        self.method = method
        self.encoder = ELKEncoder(in_channels, d_model, num_layers,
                                   kernel_size, output_dim, dropout)

        # Augmentation strategy
        if method == 'sl_only':
            self.augmentation = None  # No SSL
        elif method == 'ssl_wo_missing':
            self.augmentation = StandardAugmentation()
        elif method == 'random_point_drop':
            self.augmentation = RandomPointDrop(drop_ratio=mask_ratio)
        elif method == 'channel_drop':
            self.augmentation = ChannelDrop(drop_prob=mask_ratio, mode=channel_drop_mode)
        elif method == 'mtc_micl':
            self.augmentation = TemporalMaskingBatch(mask_ratio=mask_ratio)
        else:
            raise ValueError(f"Unknown method: {method}")

        # Loss functions
        self.mtc_loss = MTCLoss()
        self.micl_loss = SymmetricNTXentLoss(temperature=temperature)
        self.lambda_mtc = lambda_mtc
        self.lambda_micl = lambda_micl

    def forward(self, x):
        if self.method == 'sl_only':
            # No SSL pretraining - return dummy loss
            return torch.tensor(0.0, device=x.device), {'total': 0.0, 'mtc': 0.0, 'micl': 0.0}

        # Clean view
        z_clean = self.encoder(x, mask=None)

        # Augmented view
        x_aug, mask = self.augmentation(x)
        z_aug = self.encoder(x_aug, mask=mask)

        # Compute losses based on method
        if self.method == 'ssl_wo_missing':
            # Only contrastive loss (no MTC)
            loss_mtc = torch.tensor(0.0, device=x.device)
            loss_micl = self.micl_loss(z_clean, z_aug)
            loss = loss_micl
        else:
            # MTC + MICL
            loss_mtc = self.mtc_loss(z_clean, z_aug)
            loss_micl = self.micl_loss(z_clean, z_aug)
            loss = self.lambda_mtc * loss_mtc + self.lambda_micl * loss_micl

        losses_dict = {
            'total': loss.item(),
            'mtc': loss_mtc.item(),
            'micl': loss_micl.item()
        }

        return loss, losses_dict

    def get_representation(self, x, mask=None):
        with torch.no_grad():
            return self.encoder(x, mask=mask)


# ============================================================================
# 7. Linear Evaluation Protocol
# ============================================================================
class LinearClassifier(nn.Module):
    """Linear evaluation protocol"""
    def __init__(self, encoder, num_classes=6):
        super().__init__()
        self.encoder = encoder
        for param in self.encoder.parameters():
            param.requires_grad = False
        self.classifier = nn.Linear(256, num_classes)

    def forward(self, x):
        with torch.no_grad():
            z = self.encoder(x, mask=None)
        return self.classifier(z)


# ============================================================================
# 8. Training & Evaluation Functions
# ============================================================================
def train_ssl_epoch(model, dataloader, optimizer, device):
    """SSL pretraining epoch"""
    model.train()
    total_loss = 0
    total_mtc = 0
    total_micl = 0

    for x, _ in dataloader:
        x = x.to(device)
        optimizer.zero_grad()
        loss, losses_dict = model(x)
        if loss.item() > 0:  # Skip SL-only
            loss.backward()
            optimizer.step()
        total_loss += losses_dict['total']
        total_mtc += losses_dict['mtc']
        total_micl += losses_dict['micl']

    num_batches = len(dataloader)
    return {
        'loss': total_loss / num_batches,
        'mtc': total_mtc / num_batches,
        'micl': total_micl / num_batches
    }


def train_linear_epoch(model, dataloader, optimizer, criterion, device):
    """Linear evaluation training epoch"""
    model.train()
    total_loss = 0
    correct = 0
    total = 0

    for x, y in dataloader:
        x, y = x.to(device), y.to(device)
        optimizer.zero_grad()
        logits = model(x)
        loss = criterion(logits, y)
        loss.backward()
        optimizer.step()

        total_loss += loss.item()
        pred = logits.argmax(dim=1)
        correct += (pred == y).sum().item()
        total += y.size(0)

    return total_loss / len(dataloader), 100.0 * correct / total


def evaluate_linear(model, dataloader, device):
    """Linear evaluation testing"""
    model.eval()
    correct = 0
    total = 0

    with torch.no_grad():
        for x, y in dataloader:
            x, y = x.to(device), y.to(device)
            logits = model(x)
            pred = logits.argmax(dim=1)
            correct += (pred == y).sum().item()
            total += y.size(0)

    return 100.0 * correct / total


def evaluate_missing_robustness(model, dataloader, device, missing_ratios=[0.0, 0.3, 0.5, 0.7]):
    """
    Evaluate robustness under different missing ratios at test time
    """
    model.eval()
    results = {}

    for ratio in missing_ratios:
        masking = TemporalMaskingBatch(mask_ratio=ratio)
        correct = 0
        total = 0

        with torch.no_grad():
            for x, y in dataloader:
                x, y = x.to(device), y.to(device)

                if ratio > 0:
                    x_masked, mask = masking(x)
                    logits = model(x_masked)
                else:
                    logits = model(x)

                pred = logits.argmax(dim=1)
                correct += (pred == y).sum().item()
                total += y.size(0)

        results[ratio] = 100.0 * correct / total

    return results


# ============================================================================
# 9. Comprehensive Benchmark
# ============================================================================
def run_comprehensive_benchmark(data_dir, device, ssl_epochs=50, linear_epochs=50):
    """
    Run comprehensive benchmark comparing all methods
    """
    print("="*80)
    print("COMPREHENSIVE BENCHMARK: ELK-MTC-MICL vs Baselines")
    print("="*80)

    # Dataset
    train_dataset = UCIHARRawDataset(data_dir, split='train')
    test_dataset = UCIHARRawDataset(data_dir, split='test')
    train_loader = DataLoader(train_dataset, batch_size=64, shuffle=True, num_workers=4)
    test_loader = DataLoader(test_dataset, batch_size=64, shuffle=False, num_workers=4)

    methods = [
        {'name': 'SL-Only', 'method': 'sl_only', 'desc': 'Supervised learning (no SSL)'},
        {'name': 'SSL w/o Missing', 'method': 'ssl_wo_missing', 'desc': 'Standard contrastive (SimCLR-style)'},
        {'name': 'Random Point Drop', 'method': 'random_point_drop', 'desc': 'Non-contiguous missing'},
        {'name': 'Channel Drop (Random)', 'method': 'channel_drop', 'desc': 'Random channel missing', 'channel_mode': 'random'},
        {'name': 'Channel Drop (Sensor)', 'method': 'channel_drop', 'desc': 'Sensor-wise missing', 'channel_mode': 'sensor'},
        {'name': 'MTC + MICL (Ours)', 'method': 'mtc_micl', 'desc': 'Contiguous temporal masking + joint loss'},
    ]

    results = defaultdict(dict)

    for config in methods:
        print(f"\n{'='*80}")
        print(f"Method: {config['name']}")
        print(f"Description: {config['desc']}")
        print(f"{'='*80}")

        # SSL Pretraining
        if config['method'] != 'sl_only':
            print(f"\n[Phase 1] SSL Pretraining ({ssl_epochs} epochs)...")

            ssl_model = UnifiedSSLFramework(
                method=config['method'],
                channel_drop_mode=config.get('channel_mode', 'random')
            ).to(device)

            optimizer = torch.optim.AdamW(ssl_model.parameters(), lr=1e-3, weight_decay=1e-5)
            scheduler = torch.optim.lr_scheduler.CosineAnnealingLR(optimizer, T_max=ssl_epochs)

            for epoch in range(ssl_epochs):
                metrics = train_ssl_epoch(ssl_model, train_loader, optimizer, device)
                scheduler.step()

                if (epoch + 1) % 10 == 0:
                    print(f"  Epoch {epoch+1}/{ssl_epochs}: Loss={metrics['loss']:.4f}, "
                          f"MTC={metrics['mtc']:.4f}, MICL={metrics['micl']:.4f}")

            encoder = ssl_model.encoder
        else:
            # SL-Only: random init encoder
            encoder = ELKEncoder().to(device)

        # Linear Evaluation
        print(f"\n[Phase 2] Linear Evaluation ({linear_epochs} epochs)...")
        linear_model = LinearClassifier(encoder, num_classes=6).to(device)
        optimizer = torch.optim.Adam(linear_model.classifier.parameters(), lr=1e-3)
        criterion = nn.CrossEntropyLoss()

        best_test_acc = 0.0
        for epoch in range(linear_epochs):
            train_loss, train_acc = train_linear_epoch(
                linear_model, train_loader, optimizer, criterion, device
            )
            test_acc = evaluate_linear(linear_model, test_loader, device)

            if test_acc > best_test_acc:
                best_test_acc = test_acc

            if (epoch + 1) % 10 == 0:
                print(f"  Epoch {epoch+1}/{linear_epochs}: Train Acc={train_acc:.2f}%, "
                      f"Test Acc={test_acc:.2f}%")

        results[config['name']]['test_accuracy'] = best_test_acc

        # Missing Robustness Test
        print(f"\n[Phase 3] Missing Robustness Test...")
        missing_results = evaluate_missing_robustness(
            linear_model, test_loader, device,
            missing_ratios=[0.0, 0.3, 0.5, 0.7]
        )
        results[config['name']]['missing_robustness'] = missing_results

        for ratio, acc in missing_results.items():
            print(f"  Missing Ratio {ratio:.1f}: {acc:.2f}%")

        # Save model
        torch.save({
            'encoder_state_dict': encoder.state_dict(),
            'classifier_state_dict': linear_model.classifier.state_dict(),
            'config': config,
            'results': results[config['name']]
        }, f"benchmark_{config['method']}.pth")

    # Summary
    print(f"\n{'='*80}")
    print("BENCHMARK SUMMARY")
    print(f"{'='*80}")
    print(f"{'Method':<25} {'Clean Acc':<12} {'30% Miss':<12} {'50% Miss':<12} {'70% Miss':<12}")
    print("-"*80)

    for method_name, result in results.items():
        clean_acc = result['missing_robustness'][0.0]
        miss_30 = result['missing_robustness'][0.3]
        miss_50 = result['missing_robustness'][0.5]
        miss_70 = result['missing_robustness'][0.7]
        print(f"{method_name:<25} {clean_acc:>10.2f}% {miss_30:>10.2f}% "
              f"{miss_50:>10.2f}% {miss_70:>10.2f}%")

    # Save results
    with open('benchmark_results.json', 'w') as f:
        json.dump(results, f, indent=2)
    print(f"\n✓ Results saved to: benchmark_results.json")

    return results


# ============================================================================
# 10. Main Entry Point
# ============================================================================
def main():
    data_dir = '/content/drive/MyDrive/Colab Notebooks/UCI-HAR/UCI-HAR'
    device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')

    print(f"Device: {device}")
    print(f"PyTorch version: {torch.__version__}\n")

    # Run comprehensive benchmark
    results = run_comprehensive_benchmark(
        data_dir=data_dir,
        device=device,
        ssl_epochs=50,
        linear_epochs=50
    )

    # Generate comparison plots
    try:
        import matplotlib.pyplot as plt

        # Plot 1: Test Accuracy Comparison
        plt.figure(figsize=(12, 5))

        plt.subplot(1, 2, 1)
        methods = list(results.keys())
        accuracies = [results[m]['test_accuracy'] for m in methods]
        colors = ['red', 'orange', 'yellow', 'lightblue', 'blue', 'green']
        bars = plt.bar(range(len(methods)), accuracies, color=colors, alpha=0.7)
        plt.xticks(range(len(methods)), methods, rotation=45, ha='right')
        plt.ylabel('Test Accuracy (%)')
        plt.title('Linear Evaluation: Test Accuracy')
        plt.ylim([0, 100])
        plt.grid(axis='y', alpha=0.3)

        # Add value labels on bars
        for i, (bar, acc) in enumerate(zip(bars, accuracies)):
            plt.text(bar.get_x() + bar.get_width()/2, bar.get_height() + 1,
                    f'{acc:.1f}%', ha='center', va='bottom', fontsize=9)

        # Plot 2: Missing Robustness
        plt.subplot(1, 2, 2)
        missing_ratios = [0.0, 0.3, 0.5, 0.7]
        for method in methods:
            accs = [results[method]['missing_robustness'][r] for r in missing_ratios]
            plt.plot(missing_ratios, accs, marker='o', label=method, linewidth=2)

        plt.xlabel('Missing Ratio')
        plt.ylabel('Test Accuracy (%)')
        plt.title('Robustness to Missing Data')
        plt.legend(loc='lower left', fontsize=8)
        plt.grid(alpha=0.3)
        plt.ylim([0, 100])

        plt.tight_layout()
        plt.savefig('benchmark_comparison.png', dpi=300, bbox_inches='tight')
        print(f"\n✓ Comparison plots saved to: benchmark_comparison.png")

    except ImportError:
        print("\n⚠ Matplotlib not available. Skipping plots.")

    print("\n" + "="*80)
    print("BENCHMARK COMPLETE!")
    print("="*80)


# ============================================================================
# 11. Additional Utility: Single Method Training
# ============================================================================
def train_single_method(method='mtc_micl', data_dir='./UCI HAR Dataset',
                       ssl_epochs=100, linear_epochs=50, device=None):
    """
    Train a single method (for detailed experimentation)

    Args:
        method: one of ['sl_only', 'ssl_wo_missing', 'random_point_drop',
                        'channel_drop', 'mtc_micl']
    """
    if device is None:
        device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')

    print(f"\n{'='*80}")
    print(f"Training Method: {method}")
    print(f"{'='*80}\n")

    # Dataset
    train_dataset = UCIHARRawDataset(data_dir, split='train')
    test_dataset = UCIHARRawDataset(data_dir, split='test')
    train_loader = DataLoader(train_dataset, batch_size=64, shuffle=True, num_workers=4)
    test_loader = DataLoader(test_dataset, batch_size=64, shuffle=False, num_workers=4)

    # SSL Pretraining
    if method != 'sl_only':
        print(f"[Phase 1] SSL Pretraining ({ssl_epochs} epochs)...\n")

        ssl_model = UnifiedSSLFramework(method=method).to(device)
        optimizer = torch.optim.AdamW(ssl_model.parameters(), lr=1e-3, weight_decay=1e-5)
        scheduler = torch.optim.lr_scheduler.CosineAnnealingLR(optimizer, T_max=ssl_epochs)

        for epoch in range(ssl_epochs):
            metrics = train_ssl_epoch(ssl_model, train_loader, optimizer, device)
            scheduler.step()

            if (epoch + 1) % 10 == 0:
                print(f"Epoch {epoch+1}/{ssl_epochs}: Loss={metrics['loss']:.4f}, "
                      f"MTC={metrics['mtc']:.4f}, MICL={metrics['micl']:.4f}")

        encoder = ssl_model.encoder
        torch.save(encoder.state_dict(), f'{method}_encoder.pth')
        print(f"\n✓ Encoder saved: {method}_encoder.pth")
    else:
        encoder = ELKEncoder().to(device)

    # Linear Evaluation
    print(f"\n[Phase 2] Linear Evaluation ({linear_epochs} epochs)...\n")
    linear_model = LinearClassifier(encoder, num_classes=6).to(device)
    optimizer = torch.optim.Adam(linear_model.classifier.parameters(), lr=1e-3)
    criterion = nn.CrossEntropyLoss()

    best_test_acc = 0.0
    for epoch in range(linear_epochs):
        train_loss, train_acc = train_linear_epoch(
            linear_model, train_loader, optimizer, criterion, device
        )
        test_acc = evaluate_linear(linear_model, test_loader, device)

        if test_acc > best_test_acc:
            best_test_acc = test_acc
            torch.save(linear_model.state_dict(), f'{method}_best_classifier.pth')

        if (epoch + 1) % 10 == 0:
            print(f"Epoch {epoch+1}/{linear_epochs}: Train Acc={train_acc:.2f}%, "
                  f"Test Acc={test_acc:.2f}%")

    print(f"\n✓ Best Test Accuracy: {best_test_acc:.2f}%")
    print(f"✓ Best model saved: {method}_best_classifier.pth")

    # Missing Robustness
    print(f"\n[Phase 3] Missing Robustness Test...\n")
    missing_results = evaluate_missing_robustness(
        linear_model, test_loader, device, missing_ratios=[0.0, 0.3, 0.5, 0.7]
    )

    print("Missing Ratio | Test Accuracy")
    print("-" * 30)
    for ratio, acc in missing_results.items():
        print(f"    {ratio:.1f}      |    {acc:.2f}%")

    return {
        'test_accuracy': best_test_acc,
        'missing_robustness': missing_results
    }


# ============================================================================
# 12. Ablation Study: Loss Components
# ============================================================================
def run_loss_ablation(data_dir='./UCI HAR Dataset', device=None):
    """
    Ablation study on loss components
    """
    if device is None:
        device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')

    print(f"\n{'='*80}")
    print("ABLATION STUDY: Loss Components")
    print(f"{'='*80}\n")

    train_dataset = UCIHARRawDataset(data_dir, split='train')
    test_dataset = UCIHARRawDataset(data_dir, split='test')
    train_loader = DataLoader(train_dataset, batch_size=64, shuffle=True, num_workers=4)
    test_loader = DataLoader(test_dataset, batch_size=64, shuffle=False, num_workers=4)

    configs = [
        {'name': 'MTC only', 'lambda_mtc': 1.0, 'lambda_micl': 0.0},
        {'name': 'MICL only', 'lambda_mtc': 0.0, 'lambda_micl': 1.0},
        {'name': 'MTC + MICL (0.5:0.5)', 'lambda_mtc': 0.5, 'lambda_micl': 0.5},
        {'name': 'MTC + MICL (1.0:1.0)', 'lambda_mtc': 1.0, 'lambda_micl': 1.0},
        {'name': 'MTC + MICL (1.0:0.5)', 'lambda_mtc': 1.0, 'lambda_micl': 0.5},
    ]

    results = {}

    for config in configs:
        print(f"\nTraining: {config['name']}")
        print(f"λ_MTC={config['lambda_mtc']}, λ_MICL={config['lambda_micl']}")
        print("-" * 40)

        # SSL
        ssl_model = UnifiedSSLFramework(
            method='mtc_micl',
            lambda_mtc=config['lambda_mtc'],
            lambda_micl=config['lambda_micl']
        ).to(device)

        optimizer = torch.optim.AdamW(ssl_model.parameters(), lr=1e-3, weight_decay=1e-5)

        for epoch in range(50):
            metrics = train_ssl_epoch(ssl_model, train_loader, optimizer, device)
            if (epoch + 1) % 10 == 0:
                print(f"  Epoch {epoch+1}: Loss={metrics['loss']:.4f}")

        # Linear eval
        linear_model = LinearClassifier(ssl_model.encoder, num_classes=6).to(device)
        optimizer = torch.optim.Adam(linear_model.classifier.parameters(), lr=1e-3)
        criterion = nn.CrossEntropyLoss()

        best_acc = 0.0
        for epoch in range(50):
            _, _ = train_linear_epoch(linear_model, train_loader, optimizer, criterion, device)
            test_acc = evaluate_linear(linear_model, test_loader, device)
            best_acc = max(best_acc, test_acc)

        results[config['name']] = best_acc
        print(f"  → Best Test Accuracy: {best_acc:.2f}%")

    # Summary
    print(f"\n{'='*60}")
    print("Ablation Study Summary")
    print(f"{'='*60}")
    for name, acc in results.items():
        print(f"{name:<30} {acc:>10.2f}%")

    return results


# ============================================================================
# 10. Main Entry Point
# ============================================================================
def main():
    data_dir = '/content/drive/MyDrive/Colab Notebooks/UCI-HAR/UCI-HAR'
    device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')

    print(f"Device: {device}")
    print(f"PyTorch version: {torch.__version__}\n")

    # Run comprehensive benchmark
    results = run_comprehensive_benchmark(
        data_dir=data_dir,
        device=device,
        ssl_epochs=50,
        linear_epochs=50
    )

    return results


if __name__ == '__main__':
    main()

Device: cuda
PyTorch version: 2.8.0+cu126

COMPREHENSIVE BENCHMARK: ELK-MTC-MICL vs Baselines

Method: SL-Only
Description: Supervised learning (no SSL)

[Phase 2] Linear Evaluation (50 epochs)...
  Epoch 10/50: Train Acc=74.96%, Test Acc=58.30%
  Epoch 20/50: Train Acc=77.04%, Test Acc=61.79%
  Epoch 30/50: Train Acc=78.20%, Test Acc=65.56%
  Epoch 40/50: Train Acc=79.03%, Test Acc=65.52%
  Epoch 50/50: Train Acc=79.62%, Test Acc=69.63%

[Phase 3] Missing Robustness Test...
  Missing Ratio 0.0: 69.63%
  Missing Ratio 0.3: 63.42%
  Missing Ratio 0.5: 54.94%
  Missing Ratio 0.7: 42.93%

Method: SSL w/o Missing
Description: Standard contrastive (SimCLR-style)

[Phase 1] SSL Pretraining (50 epochs)...
  Epoch 10/50: Loss=0.1688, MTC=0.0000, MICL=0.1688
  Epoch 20/50: Loss=0.1055, MTC=0.0000, MICL=0.1055
  Epoch 30/50: Loss=0.0817, MTC=0.0000, MICL=0.0817
  Epoch 40/50: Loss=0.0634, MTC=0.0000, MICL=0.0634
  Epoch 50/50: Loss=0.0619, MTC=0.0000, MICL=0.0619

[Phase 2] Linear Evaluation (50



  Epoch 30/50: Loss=0.0701, MTC=0.0009, MICL=0.0692




  Epoch 40/50: Loss=0.0568, MTC=0.0009, MICL=0.0559
  Epoch 50/50: Loss=0.0480, MTC=0.0009, MICL=0.0471

[Phase 2] Linear Evaluation (50 epochs)...
  Epoch 10/50: Train Acc=54.35%, Test Acc=45.30%
  Epoch 20/50: Train Acc=58.38%, Test Acc=48.52%
  Epoch 30/50: Train Acc=60.95%, Test Acc=52.02%
  Epoch 40/50: Train Acc=63.32%, Test Acc=53.38%
  Epoch 50/50: Train Acc=64.00%, Test Acc=54.02%

[Phase 3] Missing Robustness Test...
  Missing Ratio 0.0: 54.02%
  Missing Ratio 0.3: 53.85%
  Missing Ratio 0.5: 49.75%
  Missing Ratio 0.7: 48.63%

BENCHMARK SUMMARY
Method                    Clean Acc    30% Miss     50% Miss     70% Miss    
--------------------------------------------------------------------------------
SL-Only                        69.63%      63.42%      54.94%      42.93%
SSL w/o Missing                71.73%      48.59%      41.77%      35.05%
Random Point Drop              54.05%      49.37%      43.94%      39.43%
Channel Drop (Random)          87.92%      64.68%      48

 Missing + Noise-Only 에 강인한 HAR

---

- Group A: Missing-Only (4개)

✓ temporal_30       # 연속 구간 30% 결측

✓ temporal_50       # 연속 구간 50% 결측

✓ temporal_70       # 연속 구간 70% 결측

✓ point_30          # 비연속 포인트 30% 결측

- Group B: Noise-Only (4개)

✓ gaussian_20db     # 가우시안 노이즈 (낮음, SNR=20dB)

✓ gaussian_10db     # 가우시안 노이즈 (중간, SNR=10dB)

✓ gaussian_5db      # 가우시안 노이즈 (높음, SNR=5dB)

✓ salt_pepper_10    # 충격 잡음 10%

- Group C: Combined (4개)

✓ temporal30_gaussian10      # 결측 30% + 가우시안 10dB

✓ temporal50_gaussian10      # 결측 50% + 가우시안 10dB

✓ temporal30_saltpepper      # 결측 30% + 충격 잡음 10%

✓ temporal50_saltpepper      # 결측 50% + 충격 잡음 10%

In [None]:
import numpy as np
import torch
import torch.nn as nn
import torch.nn.functional as F
from torch.utils.data import Dataset, DataLoader
from pathlib import Path
import json
from collections import defaultdict

# ============================================================================
# 1. UCI-HAR Raw 데이터 로더
# ============================================================================
class UCIHARRawDataset(Dataset):
    """UCI-HAR Raw 데이터셋"""
    def __init__(self, data_dir, split='train'):
        self.data_dir = Path(data_dir)
        self.split = split

        if split == 'train':
            self.data = self._load_split('train')
        else:
            self.data = self._load_split('test')

        self.signals, self.labels = self.data

    def _load_split(self, split):
        signals_dir = self.data_dir / split / 'Inertial Signals'

        signal_types = [
            'body_acc_x', 'body_acc_y', 'body_acc_z',
            'body_gyro_x', 'body_gyro_y', 'body_gyro_z',
            'total_acc_x', 'total_acc_y', 'total_acc_z'
        ]

        signals = []
        for sig_type in signal_types:
            filename = signals_dir / f'{sig_type}_{split}.txt'
            data = np.loadtxt(filename)
            signals.append(data)

        signals = np.stack(signals, axis=-1)

        labels_file = self.data_dir / split / f'y_{split}.txt'
        labels = np.loadtxt(labels_file, dtype=np.int32) - 1

        return signals, labels

    def __len__(self):
        return len(self.labels)

    def __getitem__(self, idx):
        x = torch.FloatTensor(self.signals[idx])
        y = torch.LongTensor([self.labels[idx]])[0]
        return x, y


# ============================================================================
# 2. Noise Generators
# ============================================================================
class GaussianNoise:
    """백색 가우시안 잡음"""
    def __init__(self, snr_db=10):
        self.snr_db = snr_db

    def __call__(self, x):
        signal_power = torch.mean(x ** 2)
        snr_linear = 10 ** (self.snr_db / 10)
        noise_power = signal_power / snr_linear
        noise = torch.randn_like(x) * torch.sqrt(noise_power)
        return x + noise


class SaltPepperNoise:
    """충격 잡음"""
    def __init__(self, corruption_ratio=0.1):
        self.corruption_ratio = corruption_ratio

    def __call__(self, x):
        B, T, C = x.shape
        mask = torch.rand(B, T, C, device=x.device) < self.corruption_ratio
        x_min, x_max = x.min(), x.max()
        salt_or_pepper = torch.rand(B, T, C, device=x.device) > 0.5

        corrupted = x.clone()
        corrupted[mask & salt_or_pepper] = x_max * 2
        corrupted[mask & ~salt_or_pepper] = x_min * 2

        return corrupted


class DriftNoise:
    """센서 드리프트"""
    def __init__(self, drift_rate=0.05, mode='linear'):
        self.drift_rate = drift_rate
        self.mode = mode

    def __call__(self, x):
        B, T, C = x.shape
        t = torch.linspace(0, 1, T, device=x.device).view(1, T, 1)

        if self.mode == 'linear':
            drift = self.drift_rate * t * torch.randn(B, 1, C, device=x.device)
        elif self.mode == 'exponential':
            drift = self.drift_rate * (torch.exp(t) - 1) * torch.randn(B, 1, C, device=x.device)

        return x + drift


class BurstNoise:
    """간헐적 강한 잡음"""
    def __init__(self, burst_ratio=0.15, burst_length_range=(5, 15), intensity=5.0):
        self.burst_ratio = burst_ratio
        self.burst_length_range = burst_length_range
        self.intensity = intensity

    def __call__(self, x):
        B, T, C = x.shape
        noisy_x = x.clone()

        for b in range(B):
            num_to_burst = int(T * self.burst_ratio)

            while num_to_burst > 0:
                length = np.random.randint(*self.burst_length_range)
                length = min(length, num_to_burst, T)
                start = np.random.randint(0, T - length + 1)

                burst = torch.randn(length, C, device=x.device) * self.intensity * x[b].std()
                noisy_x[b, start:start+length] += burst

                num_to_burst -= length

        return noisy_x


# ============================================================================
# 3. Missing Pattern Generators
# ============================================================================
class TemporalMaskingBatch:
    """연속 구간 마스킹"""
    def __init__(self, mask_ratio=0.3, mask_length_range=(10, 30)):
        self.mask_ratio = mask_ratio
        self.mask_length_range = mask_length_range

    def __call__(self, x):
        B, T, C = x.shape
        mask = torch.ones(B, T, dtype=torch.bool, device=x.device)

        max_len = min(self.mask_length_range[1], T)
        min_len = min(self.mask_length_range[0], T)

        for b in range(B):
            num_to_mask = int(T * self.mask_ratio)

            while num_to_mask > 0:
                length = np.random.randint(min_len, max_len + 1)
                length = min(length, num_to_mask, T)
                start = np.random.randint(0, T - length + 1)
                mask[b, start:start + length] = False
                num_to_mask -= length

        masked_x = x.clone()
        masked_x[~mask] = 0.0

        return masked_x, mask


class RandomPointDrop:
    """비연속 포인트 드롭"""
    def __init__(self, drop_ratio=0.3):
        self.drop_ratio = drop_ratio

    def __call__(self, x):
        B, T, C = x.shape
        mask = torch.rand(B, T, device=x.device) > self.drop_ratio

        dropped_x = x.clone()
        dropped_x[~mask] = 0.0

        return dropped_x, mask


class ChannelDrop:
    """채널/센서 드롭"""
    def __init__(self, drop_prob=0.3, mode='random'):
        self.drop_prob = drop_prob
        self.mode = mode

    def __call__(self, x):
        B, T, C = x.shape

        if self.mode == 'random':
            channel_mask = torch.rand(B, C, device=x.device) > self.drop_prob

        elif self.mode == 'axis':
            axis_groups = [[0,3,6], [1,4,7], [2,5,8]]
            channel_mask = torch.ones(B, C, dtype=torch.bool, device=x.device)
            for b in range(B):
                for group in axis_groups:
                    if np.random.rand() < self.drop_prob:
                        channel_mask[b, group] = False

        elif self.mode == 'sensor':
            sensor_groups = [[0,1,2], [3,4,5], [6,7,8]]
            channel_mask = torch.ones(B, C, dtype=torch.bool, device=x.device)
            for b in range(B):
                for group in sensor_groups:
                    if np.random.rand() < self.drop_prob:
                        channel_mask[b, group] = False

        dropped_x = x.clone()
        dropped_x *= channel_mask.unsqueeze(1).float()

        temporal_mask = channel_mask.any(dim=1, keepdim=True).expand(B, T)

        return dropped_x, temporal_mask


class StandardAugmentation:
    """표준 SSL 증강"""
    def __init__(self, noise_std=0.05, scale_range=(0.8, 1.2)):
        self.noise_std = noise_std
        self.scale_range = scale_range

    def __call__(self, x):
        B, T, C = x.shape

        noise = torch.randn_like(x) * self.noise_std
        augmented_x = x + noise

        scale = torch.empty(B, 1, C, device=x.device).uniform_(*self.scale_range)
        augmented_x = augmented_x * scale

        mask = torch.ones(B, T, dtype=torch.bool, device=x.device)

        return augmented_x, mask


class HybridCorruption:
    """결측 + 노이즈 동시 적용"""
    def __init__(self, missing_aug, noise_aug):
        self.missing_aug = missing_aug
        self.noise_aug = noise_aug

    def __call__(self, x):
        x_noisy = self.noise_aug(x)
        x_corrupted, mask = self.missing_aug(x_noisy)
        return x_corrupted, mask


# ============================================================================
# 4. ELK Backbone
# ============================================================================
class ELKBlock(nn.Module):
    def __init__(self, in_channels, out_channels, kernel_size=31, deploy=False):
        super().__init__()
        self.deploy = deploy
        self.in_channels = in_channels
        self.out_channels = out_channels
        self.kernel_size = kernel_size

        padding_large1 = kernel_size // 2
        kernel_size_large2 = kernel_size - 2
        padding_large2 = kernel_size_large2 // 2
        kernel_size_small1 = 5
        padding_small1 = kernel_size_small1 // 2
        kernel_size_small2 = 3
        padding_small2 = kernel_size_small2 // 2

        if deploy:
            self.reparam_conv = nn.Conv1d(
                in_channels, in_channels, kernel_size,
                padding=padding_large1, groups=in_channels, bias=True
            )
        else:
            self.dw_large1 = nn.Conv1d(in_channels, in_channels, kernel_size,
                                       padding=padding_large1, groups=in_channels, bias=False)
            self.bn_large1 = nn.BatchNorm1d(in_channels)
            self.dw_large2 = nn.Conv1d(in_channels, in_channels, kernel_size_large2,
                                       padding=padding_large2, groups=in_channels, bias=False)
            self.bn_large2 = nn.BatchNorm1d(in_channels)
            self.dw_small1 = nn.Conv1d(in_channels, in_channels, kernel_size_small1,
                                       padding=padding_small1, groups=in_channels, bias=False)
            self.bn_small1 = nn.BatchNorm1d(in_channels)
            self.dw_small2 = nn.Conv1d(in_channels, in_channels, kernel_size_small2,
                                       padding=padding_small2, groups=in_channels, bias=False)
            self.bn_small2 = nn.BatchNorm1d(in_channels)
            self.bn_id = nn.BatchNorm1d(in_channels)

        self.pointwise = nn.Sequential(
            nn.Conv1d(in_channels, out_channels, kernel_size=1, bias=False),
            nn.BatchNorm1d(out_channels),
        )
        self.activation = nn.GELU()

    def forward(self, x):
        if self.deploy:
            x = self.reparam_conv(x)
        else:
            x1 = self.bn_large1(self.dw_large1(x))
            x2 = self.bn_large2(self.dw_large2(x))
            x3 = self.bn_small1(self.dw_small1(x))
            x4 = self.bn_small2(self.dw_small2(x))
            x5 = self.bn_id(x)
            x = x1 + x2 + x3 + x4 + x5
        x = self.activation(x)
        return self.pointwise(x)


class ELKBackbone(nn.Module):
    def __init__(self, in_channels=9, d_model=128, num_layers=6, kernel_size=31, dropout=0.1):
        super().__init__()
        self.stem = nn.Sequential(
            nn.Conv1d(in_channels, d_model, kernel_size=3, padding=1, bias=False),
            nn.BatchNorm1d(d_model),
            nn.GELU(),
        )
        layers = []
        for _ in range(num_layers):
            layers.append(ELKBlock(d_model, d_model, kernel_size=kernel_size))
            layers.append(nn.Dropout(dropout))
        self.elk_layers = nn.Sequential(*layers)
        self.out_channels = d_model

    def forward(self, x):
        x = self.stem(x)
        x = self.elk_layers(x)
        return x


def masked_global_avg_pool(h, mask):
    mask_float = mask.unsqueeze(1).float()
    numerator = (h * mask_float).sum(dim=-1)
    denominator = mask_float.sum(dim=-1).clamp_min(1e-6)
    return numerator / denominator


class ELKEncoder(nn.Module):
    def __init__(self, in_channels=9, d_model=128, num_layers=6,
                 kernel_size=31, output_dim=256, dropout=0.1):
        super().__init__()
        self.backbone = ELKBackbone(in_channels, d_model, num_layers, kernel_size, dropout)
        self.projection = nn.Sequential(
            nn.Linear(d_model, output_dim),
            nn.BatchNorm1d(output_dim),
            nn.GELU(),
            nn.Dropout(dropout),
            nn.Linear(output_dim, output_dim)
        )

    def forward(self, x, mask=None):
        x = x.transpose(1, 2)
        h = self.backbone(x)
        if mask is not None:
            h = masked_global_avg_pool(h, mask)
        else:
            h = h.mean(dim=-1)
        z = self.projection(h)
        z = F.normalize(z, dim=1)
        return z


# ============================================================================
# 5. Loss Functions
# ============================================================================
class MTCLoss(nn.Module):
    def forward(self, z_clean, z_masked):
        return F.mse_loss(z_clean, z_masked)


class SymmetricNTXentLoss(nn.Module):
    def __init__(self, temperature=0.07):
        super().__init__()
        self.temperature = temperature

    def forward(self, z1, z2):
        B = z1.shape[0]
        device = z1.device

        z1 = F.normalize(z1, dim=1)
        z2 = F.normalize(z2, dim=1)

        sim_11 = (z1 @ z1.T) / self.temperature
        sim_22 = (z2 @ z2.T) / self.temperature
        sim_12 = (z1 @ z2.T) / self.temperature
        sim_21 = sim_12.T

        mask = torch.eye(B, device=device, dtype=torch.bool)
        sim_11 = sim_11.masked_fill(mask, -9e15)
        sim_22 = sim_22.masked_fill(mask, -9e15)

        logits_1 = torch.cat([sim_12, sim_11], dim=1)
        logits_2 = torch.cat([sim_21, sim_22], dim=1)
        labels = torch.arange(B, device=device)

        loss_1 = F.cross_entropy(logits_1, labels)
        loss_2 = F.cross_entropy(logits_2, labels)

        return 0.5 * (loss_1 + loss_2)


# ============================================================================
# 6. Unified SSL Framework
# ============================================================================
class UnifiedSSLFramework(nn.Module):
    def __init__(self, method='mtc_micl', in_channels=9, d_model=128,
                 num_layers=6, kernel_size=31, output_dim=256, dropout=0.1,
                 mask_ratio=0.3, temperature=0.07,
                 lambda_mtc=1.0, lambda_micl=1.0, channel_drop_mode='random'):
        super().__init__()

        self.method = method
        self.encoder = ELKEncoder(in_channels, d_model, num_layers,
                                   kernel_size, output_dim, dropout)

        if method == 'sl_only':
            self.augmentation = None
        elif method == 'ssl_wo_missing':
            self.augmentation = StandardAugmentation()
        elif method == 'random_point_drop':
            self.augmentation = RandomPointDrop(drop_ratio=mask_ratio)
        elif method == 'channel_drop':
            self.augmentation = ChannelDrop(drop_prob=mask_ratio, mode=channel_drop_mode)
        elif method == 'mtc_micl':
            self.augmentation = TemporalMaskingBatch(mask_ratio=mask_ratio)
        else:
            raise ValueError(f"Unknown method: {method}")

        self.mtc_loss = MTCLoss()
        self.micl_loss = SymmetricNTXentLoss(temperature=temperature)
        self.lambda_mtc = lambda_mtc
        self.lambda_micl = lambda_micl

    def forward(self, x):
        if self.method == 'sl_only':
            return torch.tensor(0.0, device=x.device), {'total': 0.0, 'mtc': 0.0, 'micl': 0.0}

        z_clean = self.encoder(x, mask=None)
        x_aug, mask = self.augmentation(x)
        z_aug = self.encoder(x_aug, mask=mask)

        if self.method == 'ssl_wo_missing':
            loss_mtc = torch.tensor(0.0, device=x.device)
            loss_micl = self.micl_loss(z_clean, z_aug)
            loss = loss_micl
        else:
            loss_mtc = self.mtc_loss(z_clean, z_aug)
            loss_micl = self.micl_loss(z_clean, z_aug)
            loss = self.lambda_mtc * loss_mtc + self.lambda_micl * loss_micl

        losses_dict = {
            'total': loss.item(),
            'mtc': loss_mtc.item(),
            'micl': loss_micl.item()
        }

        return loss, losses_dict

    def get_representation(self, x, mask=None):
        with torch.no_grad():
            return self.encoder(x, mask=mask)


# ============================================================================
# 7. Linear Classifier
# ============================================================================
class LinearClassifier(nn.Module):
    def __init__(self, encoder, num_classes=6):
        super().__init__()
        self.encoder = encoder
        for param in self.encoder.parameters():
            param.requires_grad = False
        self.classifier = nn.Linear(256, num_classes)

    def forward(self, x):
        with torch.no_grad():
            z = self.encoder(x, mask=None)
        return self.classifier(z)


# ============================================================================
# 8. Robustness Scenario Creation
# ============================================================================
def create_robustness_scenarios():
    """12가지 핵심 robustness 시나리오"""
    scenarios = {
        # Group A: Missing-Only
        'temporal_30': {
            'corruption': TemporalMaskingBatch(mask_ratio=0.3),
            'description': 'Temporal missing 30%',
            'group': 'Missing-Only'
        },
        'temporal_50': {
            'corruption': TemporalMaskingBatch(mask_ratio=0.5),
            'description': 'Temporal missing 50%',
            'group': 'Missing-Only'
        },
        'temporal_70': {
            'corruption': TemporalMaskingBatch(mask_ratio=0.7),
            'description': 'Temporal missing 70%',
            'group': 'Missing-Only'
        },
        'point_30': {
            'corruption': RandomPointDrop(drop_ratio=0.3),
            'description': 'Point missing 30%',
            'group': 'Missing-Only'
        },

        # Group B: Noise-Only
        'gaussian_20db': {
            'corruption': GaussianNoise(snr_db=20),
            'description': 'Gaussian noise (SNR=20dB)',
            'group': 'Noise-Only'
        },
        'gaussian_10db': {
            'corruption': GaussianNoise(snr_db=10),
            'description': 'Gaussian noise (SNR=10dB)',
            'group': 'Noise-Only'
        },
        'gaussian_5db': {
            'corruption': GaussianNoise(snr_db=5),
            'description': 'Gaussian noise (SNR=5dB)',
            'group': 'Noise-Only'
        },
        'salt_pepper_10': {
            'corruption': SaltPepperNoise(corruption_ratio=0.1),
            'description': 'Salt-pepper noise 10%',
            'group': 'Noise-Only'
        },

        # Group C: Combined
        'temporal30_gaussian10': {
            'corruption': HybridCorruption(
                TemporalMaskingBatch(0.3),
                GaussianNoise(10)
            ),
            'description': 'Temporal 30% + Gaussian 10dB',
            'group': 'Combined'
        },
        'temporal50_gaussian10': {
            'corruption': HybridCorruption(
                TemporalMaskingBatch(0.5),
                GaussianNoise(10)
            ),
            'description': 'Temporal 50% + Gaussian 10dB',
            'group': 'Combined'
        },
        'temporal30_saltpepper': {
            'corruption': HybridCorruption(
                TemporalMaskingBatch(0.3),
                SaltPepperNoise(0.1)
            ),
            'description': 'Temporal 30% + Salt-pepper 10%',
            'group': 'Combined'
        },
        'temporal50_saltpepper': {
            'corruption': HybridCorruption(
                TemporalMaskingBatch(0.5),
                SaltPepperNoise(0.1)
            ),
            'description': 'Temporal 50% + Salt-pepper 10%',
            'group': 'Combined'
        },
    }

    return scenarios


# ============================================================================
# 9. Training Functions
# ============================================================================
def train_ssl_epoch(model, dataloader, optimizer, device):
    model.train()
    total_loss = 0
    total_mtc = 0
    total_micl = 0

    for x, _ in dataloader:
        x = x.to(device)
        optimizer.zero_grad()
        loss, losses_dict = model(x)
        if loss.item() > 0:
            loss.backward()
            optimizer.step()
        total_loss += losses_dict['total']
        total_mtc += losses_dict['mtc']
        total_micl += losses_dict['micl']

    num_batches = len(dataloader)
    return {
        'loss': total_loss / num_batches,
        'mtc': total_mtc / num_batches,
        'micl': total_micl / num_batches
    }


def train_linear_epoch(model, dataloader, optimizer, criterion, device):
    model.train()
    total_loss = 0
    correct = 0
    total = 0

    for x, y in dataloader:
        x, y = x.to(device), y.to(device)
        optimizer.zero_grad()
        logits = model(x)
        loss = criterion(logits, y)
        loss.backward()
        optimizer.step()

        total_loss += loss.item()
        pred = logits.argmax(dim=1)
        correct += (pred == y).sum().item()
        total += y.size(0)

    return total_loss / len(dataloader), 100.0 * correct / total


def evaluate_linear(model, dataloader, device):
    model.eval()
    correct = 0
    total = 0

    with torch.no_grad():
        for x, y in dataloader:
            x, y = x.to(device), y.to(device)
            logits = model(x)
            pred = logits.argmax(dim=1)
            correct += (pred == y).sum().item()
            total += y.size(0)

    return 100.0 * correct / total


# ============================================================================
# 10. Comprehensive Robustness Evaluation
# ============================================================================
def evaluate_comprehensive_robustness(model, test_loader, device, scenarios):
    """모든 시나리오에서 robustness 평가"""
    model.eval()
    results = {}

    print("\n" + "="*80)
    print("COMPREHENSIVE ROBUSTNESS EVALUATION")
    print("="*80)

    for scenario_name, scenario_config in scenarios.items():
        corruption = scenario_config['corruption']
        description = scenario_config['description']
        group = scenario_config['group']

        print(f"\n[{group}] {scenario_name}")
        print(f"  Description: {description}")

        correct = 0
        total = 0

        with torch.no_grad():
            for x, y in test_loader:
                x, y = x.to(device), y.to(device)

                # Apply corruption
                if isinstance(corruption, HybridCorruption):
                    x_corrupted, _ = corruption(x)
                elif hasattr(corruption, '__call__'):
                    if 'Noise' in corruption.__class__.__name__:
                        x_corrupted = corruption(x)
                    else:
                        x_corrupted, _ = corruption(x)
                else:
                    x_corrupted = x

                logits = model(x_corrupted)
                pred = logits.argmax(dim=1)
                correct += (pred == y).sum().item()
                total += y.size(0)

        accuracy = 100.0 * correct / total
        results[scenario_name] = {
            'accuracy': accuracy,
            'group': group,
            'description': description
        }

        print(f"  Accuracy: {accuracy:.2f}%")

    return results


# ============================================================================
# 11. Visualization
# ============================================================================
def plot_robustness_heatmap(all_results, save_path='robustness_heatmap.png'):
    """Robustness heatmap 생성"""
    try:
        import matplotlib.pyplot as plt
        import seaborn as sns

        methods = list(all_results.keys())
        scenarios = list(next(iter(all_results.values())).keys())

        matrix = []
        for method in methods:
            row = [all_results[method][scenario]['accuracy'] for scenario in scenarios]
            matrix.append(row)

        plt.figure(figsize=(16, 8))
        sns.heatmap(
            matrix,
            annot=True,
            fmt='.1f',
            cmap='RdYlGn',
            xticklabels=scenarios,
            yticklabels=methods,
            vmin=0,
            vmax=100,
            cbar_kws={'label': 'Accuracy (%)'}
        )
        plt.title('Robustness Comparison Across Corruption Scenarios',
                  fontsize=16, fontweight='bold')
        plt.xlabel('Corruption Scenario', fontsize=12)
        plt.ylabel('Method', fontsize=12)
        plt.xticks(rotation=45, ha='right')
        plt.tight_layout()
        plt.savefig(save_path, dpi=300, bbox_inches='tight')
        print(f"\n✓ Heatmap saved: {save_path}")
    except ImportError:
        print("\n⚠ matplotlib/seaborn not available, skipping heatmap generation")


# ============================================================================
# 12. Comprehensive Benchmark
# ============================================================================
def run_comprehensive_benchmark(data_dir, device, ssl_epochs=50, linear_epochs=50):
    """전체 벤치마크 실행"""
    print("="*80)
    print("COMPREHENSIVE BENCHMARK: ELK-MTC-MICL vs Baselines")
    print("="*80)

    train_dataset = UCIHARRawDataset(data_dir, split='train')
    test_dataset = UCIHARRawDataset(data_dir, split='test')
    train_loader = DataLoader(train_dataset, batch_size=64, shuffle=True, num_workers=4)
    test_loader = DataLoader(test_dataset, batch_size=64, shuffle=False, num_workers=4)

    methods = [
        {'name': 'SL-Only', 'method': 'sl_only'},
        {'name': 'SSL w/o Missing', 'method': 'ssl_wo_missing'},
        {'name': 'Random Point Drop', 'method': 'random_point_drop'},
        {'name': 'Channel Drop', 'method': 'channel_drop', 'channel_mode': 'random'},
        {'name': 'MTC + MICL (Ours)', 'method': 'mtc_micl'},
    ]

    all_results = {}
    scenarios = create_robustness_scenarios()

    for config in methods:
        print(f"\n{'='*80}")
        print(f"Method: {config['name']}")
        print(f"{'='*80}")

        # SSL Pretraining
        if config['method'] != 'sl_only':
            print(f"\n[Phase 1] SSL Pretraining ({ssl_epochs} epochs)...")

            ssl_model = UnifiedSSLFramework(
                method=config['method'],
                channel_drop_mode=config.get('channel_mode', 'random')
            ).to(device)

            optimizer = torch.optim.AdamW(ssl_model.parameters(), lr=1e-3, weight_decay=1e-5)
            scheduler = torch.optim.lr_scheduler.CosineAnnealingLR(optimizer, T_max=ssl_epochs)

            for epoch in range(ssl_epochs):
                metrics = train_ssl_epoch(ssl_model, train_loader, optimizer, device)
                scheduler.step()

                if (epoch + 1) % 10 == 0:
                    print(f"  Epoch {epoch+1}/{ssl_epochs}: Loss={metrics['loss']:.4f}")

            encoder = ssl_model.encoder
        else:
            encoder = ELKEncoder().to(device)

        # Linear Evaluation
        print(f"\n[Phase 2] Linear Evaluation ({linear_epochs} epochs)...")
        linear_model = LinearClassifier(encoder, num_classes=6).to(device)
        optimizer = torch.optim.Adam(linear_model.classifier.parameters(), lr=1e-3)
        criterion = nn.CrossEntropyLoss()

        best_test_acc = 0.0
        for epoch in range(linear_epochs):
            train_loss, train_acc = train_linear_epoch(
                linear_model, train_loader, optimizer, criterion, device
            )
            test_acc = evaluate_linear(linear_model, test_loader, device)

            if test_acc > best_test_acc:
                best_test_acc = test_acc

            if (epoch + 1) % 10 == 0:
                print(f"  Epoch {epoch+1}/{linear_epochs}: Test Acc={test_acc:.2f}%")

        print(f"\n[Phase 3] Comprehensive Robustness Test...")
        robustness_results = evaluate_comprehensive_robustness(
            linear_model, test_loader, device, scenarios
        )

        all_results[config['name']] = robustness_results

        # Save checkpoint
        torch.save({
            'encoder_state_dict': encoder.state_dict(),
            'classifier_state_dict': linear_model.classifier.state_dict(),
            'config': config,
            'best_test_acc': best_test_acc,
            'robustness_results': robustness_results
        }, f"benchmark_{config['method']}.pth")

    # Summary by Group
    print(f"\n{'='*80}")
    print("ROBUSTNESS SUMMARY BY GROUP")
    print(f"{'='*80}")

    groups = ['Missing-Only', 'Noise-Only', 'Combined']

    for group in groups:
        print(f"\n{group}:")
        print("-" * 80)
        print(f"{'Method':<25} {'Avg Accuracy':<15} {'Best Scenario':<20} {'Worst Scenario':<20}")
        print("-" * 80)

        for method_name, results in all_results.items():
            group_results = {k: v['accuracy'] for k, v in results.items() if v['group'] == group}

            if group_results:
                avg_acc = np.mean(list(group_results.values()))
                best_scenario = max(group_results.items(), key=lambda x: x[1])
                worst_scenario = min(group_results.items(), key=lambda x: x[1])

                print(f"{method_name:<25} {avg_acc:>13.2f}% {best_scenario[0]:<20} {worst_scenario[0]:<20}")

    # Overall Summary
    print(f"\n{'='*80}")
    print("OVERALL ROBUSTNESS COMPARISON")
    print(f"{'='*80}")
    print(f"{'Method':<25} {'Avg All':<12} {'Missing':<12} {'Noise':<12} {'Combined':<12}")
    print("-" * 80)

    for method_name, results in all_results.items():
        all_acc = np.mean([v['accuracy'] for v in results.values()])
        missing_acc = np.mean([v['accuracy'] for v in results.values() if v['group'] == 'Missing-Only'])
        noise_acc = np.mean([v['accuracy'] for v in results.values() if v['group'] == 'Noise-Only'])
        combined_acc = np.mean([v['accuracy'] for v in results.values() if v['group'] == 'Combined'])

        print(f"{method_name:<25} {all_acc:>10.2f}% {missing_acc:>10.2f}% "
              f"{noise_acc:>10.2f}% {combined_acc:>10.2f}%")

    # Detailed Table
    print(f"\n{'='*80}")
    print("DETAILED SCENARIO RESULTS")
    print(f"{'='*80}")
    print(f"{'Scenario':<30}", end='')
    for method_name in all_results.keys():
        print(f"{method_name:<15}", end='')
    print()
    print("-" * 80)

    for scenario_name in scenarios.keys():
        print(f"{scenario_name:<30}", end='')
        for method_name in all_results.keys():
            acc = all_results[method_name][scenario_name]['accuracy']
            print(f"{acc:>13.2f}%", end='')
        print()

    # Save results
    with open('benchmark_results.json', 'w') as f:
        json.dump(all_results, f, indent=2)
    print(f"\n✓ Results saved to: benchmark_results.json")

    # Generate heatmap
    plot_robustness_heatmap(all_results)

    return all_results


# ============================================================================
# 13. Main Entry Point
# ============================================================================
def main():
    data_dir = './UCI HAR Dataset'
    device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')

    print(f"Device: {device}")
    print(f"PyTorch version: {torch.__version__}\n")

    # Run comprehensive benchmark
    results = run_comprehensive_benchmark(
        data_dir=data_dir,
        device=device,
        ssl_epochs=50,
        linear_epochs=50
    )

    return results


if __name__ == '__main__':
    main()