In [2]:
# Notebook-only dependency install; %pip targets the running kernel's environment.
%pip install -q torchprofile==0.0.4 torchsummary

Collecting torchprofile
  Downloading torchprofile-0.0.4-py3-none-any.whl.metadata (303 bytes)
Downloading torchprofile-0.0.4-py3-none-any.whl (7.7 kB)
Installing collected packages: torchprofile
Successfully installed torchprofile-0.0.4


In [20]:
import numpy as np
import torch
import torch.nn as nn
import torch.nn.functional as F
from torch.utils.data import Dataset, DataLoader, TensorDataset
from pathlib import Path
import json
from collections import defaultdict
import os
import time
from sklearn.preprocessing import StandardScaler

def load_uci_har_raw(dataset_path):
    """Load the nine raw inertial signals and labels, then standardize them.

    Args:
        dataset_path: Directory holding ``<signal>_train.txt`` / ``<signal>_test.txt``
            for each of the nine signals, plus ``y_train.txt`` and ``y_test.txt``.

    Returns:
        ``(X_train, y_train, X_test, y_test, activity_names)``. The X arrays are
        laid out as (samples, signals, timesteps); labels have 1 subtracted
        (presumably mapping 1-based file labels to 0-based class ids).
    """
    signal_names = ["body_acc_x", "body_acc_y", "body_acc_z", "body_gyro_x", "body_gyro_y", "body_gyro_z", "total_acc_x", "total_acc_y", "total_acc_z"]
    activity_names = ['Walking', 'Walking Upstairs', 'Walking Downstairs', 'Sitting', 'Standing', 'Laying']

    train_signals = []
    test_signals = []
    for name in signal_names:
        train_signals.append(np.loadtxt(os.path.join(dataset_path, f"{name}_train.txt"), dtype=np.float32))
        test_signals.append(np.loadtxt(os.path.join(dataset_path, f"{name}_test.txt"), dtype=np.float32))

    # Stack signals last, then move them in front of the time axis.
    X_train = np.stack(train_signals, axis=-1).transpose(0, 2, 1)
    X_test = np.stack(test_signals, axis=-1).transpose(0, 2, 1)

    y_train = np.loadtxt(os.path.join(dataset_path, 'y_train.txt'), dtype=int) - 1
    y_test = np.loadtxt(os.path.join(dataset_path, 'y_test.txt'), dtype=int) - 1

    # Statistics are fit on the flattened training windows only and reused for
    # the test split, so every (signal, timestep) position is scaled separately.
    scaler = StandardScaler()
    train_flat = X_train.reshape(X_train.shape[0], -1)
    test_flat = X_test.reshape(X_test.shape[0], -1)
    X_train_scaled = scaler.fit_transform(train_flat).reshape(X_train.shape)
    X_test_scaled = scaler.transform(test_flat).reshape(X_test.shape)

    return X_train_scaled, y_train, X_test_scaled, y_test, activity_names

class ImprovedTemporalMasking:
    """Replace random contiguous time spans with Gaussian noise, per sample.

    Args:
        mask_ratio: Fraction of the ``T`` timesteps budgeted for masking in
            each sample (``int(T * mask_ratio)`` steps).
        mask_length_range: Inclusive ``(min, max)`` span length; both ends are
            clamped to ``T``.
        noise_std: Standard deviation of the noise written into masked steps.
    """

    def __init__(self, mask_ratio=0.15, mask_length_range=(5, 15), noise_std=0.1):
        self.mask_ratio = mask_ratio
        self.mask_length_range = mask_length_range
        self.noise_std = noise_std

    def __call__(self, x):
        """Mask ``x`` of shape (B, C, T).

        Returns:
            ``(masked_x, mask)`` where ``mask`` is a (B, T) bool tensor that is
            True for kept timesteps and False for masked ones. All channels of
            a masked timestep are replaced. Spans may overlap, so the realised
            masked fraction can be smaller than ``mask_ratio``.
        """
        B, C, T = x.shape

        max_len = min(self.mask_length_range[1], T)
        min_len = min(self.mask_length_range[0], T)
        budget = int(T * self.mask_ratio)

        # Build the mask on the host first so span writes do not each launch a
        # device kernel; it is moved to x's device once at the end.
        keep = np.ones((B, T), dtype=bool)
        for b in range(B):
            num_to_mask = budget
            while num_to_mask > 0:
                length = np.random.randint(min_len, max_len + 1)
                # max(1, ...) guarantees progress even if the range allows 0.
                length = max(1, min(length, num_to_mask, T))
                start = np.random.randint(0, T - length + 1)
                keep[b, start:start + length] = False
                num_to_mask -= length

        mask = torch.from_numpy(keep).to(x.device)
        noise = torch.randn_like(x) * self.noise_std
        # Apply each sample's own mask (broadcast over channels); using only
        # mask[0] would corrupt every sample with the first sample's pattern.
        masked_x = torch.where(mask.unsqueeze(1), x, noise)

        return masked_x, mask

class ImprovedRandomPointDrop:
    """Replace independently chosen timesteps with Gaussian noise, per sample.

    Args:
        drop_ratio: Probability that any given (sample, timestep) is dropped.
        noise_std: Standard deviation of the noise written into dropped steps.
    """

    def __init__(self, drop_ratio=0.15, noise_std=0.1):
        self.drop_ratio = drop_ratio
        self.noise_std = noise_std

    def __call__(self, x):
        """Drop points from ``x`` of shape (B, C, T).

        Returns:
            ``(dropped_x, mask)`` where ``mask`` is a (B, T) bool tensor, True
            for kept timesteps. A dropped timestep is replaced in all channels.
        """
        B, C, T = x.shape
        mask = torch.rand(B, T, device=x.device) > self.drop_ratio

        noise = torch.randn_like(x) * self.noise_std
        # Each sample uses its own row of the mask (broadcast over channels),
        # so the returned mask matches what was actually corrupted.
        dropped_x = torch.where(mask.unsqueeze(1), x, noise)

        return dropped_x, mask

class ImprovedChannelDrop:
    """Replace whole channels (or channel groups) with Gaussian noise.

    Args:
        drop_prob: Probability of dropping each channel (``'random'``) or each
            channel group (``'axis'`` / ``'sensor'``).
        mode: ``'random'`` drops channels independently; ``'axis'`` drops the
            index groups [0,3,6], [1,4,7], [2,5,8]; ``'sensor'`` drops the
            groups [0,1,2], [3,4,5], [6,7,8].
        noise_std: Standard deviation of the replacement noise.
    """

    _AXIS_GROUPS = ((0, 3, 6), (1, 4, 7), (2, 5, 8))
    _SENSOR_GROUPS = ((0, 1, 2), (3, 4, 5), (6, 7, 8))

    def __init__(self, drop_prob=0.2, mode='random', noise_std=0.1):
        self.drop_prob = drop_prob
        self.mode = mode
        self.noise_std = noise_std

    def _grouped_channel_mask(self, groups, B, C, device):
        """Return a (B, C) keep-mask where each group is kept or dropped as a unit."""
        channel_mask = torch.ones(B, C, dtype=torch.bool, device=device)
        # One uniform draw per (sample, group).
        drop = torch.from_numpy(np.random.rand(B, len(groups)) < self.drop_prob).to(device)
        keep = ~drop
        for g, group in enumerate(groups):
            channel_mask[:, list(group)] = keep[:, g].unsqueeze(1).expand(B, len(group))
        return channel_mask

    def __call__(self, x):
        """Drop channels from ``x`` of shape (B, C, T).

        Returns:
            ``(dropped_x, temporal_mask)``. ``temporal_mask`` is a (B, T) bool
            tensor that is True for a whole sample if at least one of its
            channels survived, and False if all of them were dropped.

        Raises:
            ValueError: If ``mode`` is not 'random', 'axis' or 'sensor'.
        """
        B, C, T = x.shape

        if self.mode == 'random':
            channel_mask = torch.rand(B, C, device=x.device) > self.drop_prob
        elif self.mode == 'axis':
            channel_mask = self._grouped_channel_mask(self._AXIS_GROUPS, B, C, x.device)
        elif self.mode == 'sensor':
            channel_mask = self._grouped_channel_mask(self._SENSOR_GROUPS, B, C, x.device)
        else:
            raise ValueError(
                f"Unknown channel drop mode: {self.mode!r} (expected 'random', 'axis' or 'sensor')"
            )

        noise = torch.randn_like(x) * self.noise_std
        # Broadcast the per-channel decision over time instead of looping over B x C.
        dropped_x = torch.where(channel_mask.unsqueeze(-1), x, noise)

        temporal_mask = channel_mask.any(dim=1, keepdim=True).expand(B, T)
        return dropped_x, temporal_mask

class StandardAugmentation:
    """Jitter-and-scale augmentation that never marks any timestep as missing."""

    def __init__(self, noise_std=0.02, scale_range=(0.9, 1.1)):
        self.noise_std, self.scale_range = noise_std, scale_range

    def __call__(self, x):
        """Augment ``x`` of shape (B, C, T).

        Adds Gaussian noise, then multiplies every (sample, channel) by a gain
        drawn uniformly from ``scale_range``.

        Returns:
            ``(augmented_x, mask)`` with ``mask`` an all-True (B, T) bool tensor.
        """
        batch, channels, steps = x.shape
        jitter = torch.randn_like(x) * self.noise_std
        gains = torch.empty(batch, channels, 1, device=x.device).uniform_(*self.scale_range)
        keep_all = torch.ones(batch, steps, dtype=torch.bool, device=x.device)
        return (x + jitter) * gains, keep_all

class ELKBlock(nn.Module):
    """Multi-branch depthwise convolution block followed by a pointwise projection.

    In training form (``deploy=False``) the input passes through four parallel
    depthwise conv + BatchNorm branches (kernel sizes ``kernel_size``,
    ``kernel_size - 2``, 5 and 3) plus a BatchNorm-only identity branch, and
    the five results are summed. In deploy form a single biased depthwise conv
    of size ``kernel_size`` stands in for the branches. Either way the result
    goes through GELU and a 1x1 conv + BatchNorm mapping ``in_channels`` to
    ``out_channels``.
    """

    def __init__(self, in_channels, out_channels, kernel_size=31, deploy=False):
        super().__init__()
        self.deploy = deploy
        self.in_channels = in_channels
        self.out_channels = out_channels
        self.kernel_size = kernel_size

        def depthwise(width, with_bias):
            # padding = width // 2, one filter per channel
            return nn.Conv1d(
                in_channels, in_channels, width,
                padding=width // 2, groups=in_channels, bias=with_bias
            )

        if deploy:
            self.reparam_conv = depthwise(kernel_size, True)
        else:
            branch_widths = {
                'large1': kernel_size,
                'large2': kernel_size - 2,
                'small1': 5,
                'small2': 3,
            }
            # Registration order (conv, then its norm, branch by branch) is kept
            # so parameter names and initialisation order stay the same.
            for tag, width in branch_widths.items():
                setattr(self, f"dw_{tag}", depthwise(width, False))
                setattr(self, f"bn_{tag}", nn.BatchNorm1d(in_channels))
            self.bn_id = nn.BatchNorm1d(in_channels)

        self.pointwise = nn.Sequential(
            nn.Conv1d(in_channels, out_channels, kernel_size=1, bias=False),
            nn.BatchNorm1d(out_channels),
        )
        self.activation = nn.GELU()

    def forward(self, x):
        """Apply the (fused or multi-branch) depthwise stage, GELU, then the 1x1 projection."""
        if self.deploy:
            mixed = self.reparam_conv(x)
        else:
            wide_a = self.bn_large1(self.dw_large1(x))
            wide_b = self.bn_large2(self.dw_large2(x))
            narrow_a = self.bn_small1(self.dw_small1(x))
            narrow_b = self.bn_small2(self.dw_small2(x))
            skip = self.bn_id(x)
            mixed = wide_a + wide_b + narrow_a + narrow_b + skip
        return self.pointwise(self.activation(mixed))

class LightweightELKBackbone(nn.Module):
    """Conv stem followed by a single ELK block and dropout.

    ``num_layers`` is accepted but not used by this class: exactly one
    ``ELKBlock`` is built regardless of its value. ``out_channels`` exposes
    the feature width (``d_model``).
    """

    def __init__(self, in_channels=9, d_model=128, num_layers=1, kernel_size=31, dropout=0.1):
        super().__init__()
        stem_layers = [
            nn.Conv1d(in_channels, d_model, kernel_size=3, padding=1, bias=False),
            nn.BatchNorm1d(d_model),
            nn.GELU(),
        ]
        self.stem = nn.Sequential(*stem_layers)

        self.elk_block = ELKBlock(d_model, d_model, kernel_size=kernel_size)
        self.dropout = nn.Dropout(dropout)
        self.out_channels = d_model

    def forward(self, x):
        """Run stem -> ELK block -> dropout on ``x``."""
        return self.dropout(self.elk_block(self.stem(x)))

def improved_masked_pooling(h, mask, pool_type='attention'):
    """Pool features over the last (time) axis, ignoring masked-out steps.

    Args:
        h: Feature tensor whose last axis is time; channels on axis 1.
        mask: (B, T) tensor, truthy where a timestep is valid. Not read when
            ``pool_type`` is neither 'attention' nor 'weighted_avg'.
        pool_type: 'attention' weights timesteps by a softmax over the
            channel-mean activation, zeroes masked steps and renormalises;
            'weighted_avg' is a masked mean; anything else is a plain mean.

    Returns:
        Tensor with the time axis reduced away.
    """
    if pool_type not in ('attention', 'weighted_avg'):
        return h.mean(dim=-1)

    if pool_type == 'weighted_avg':
        valid = mask.unsqueeze(1).float()
        valid_count = valid.sum(dim=-1).clamp_min(1e-6)
        return (h * valid).sum(dim=-1) / valid_count

    batch, channels, steps = h.shape  # also enforces a 3-D input
    valid = mask.float().unsqueeze(1)
    scores = torch.softmax(h.mean(dim=1, keepdim=True), dim=-1) * valid
    weights = scores / (scores.sum(dim=-1, keepdim=True) + 1e-8)
    return (h * weights).sum(dim=-1)

class ImprovedELKEncoder(nn.Module):
    """Backbone + temporal pooling + MLP projection producing unit-norm embeddings."""

    def __init__(self, in_channels=9, d_model=128, num_layers=1,
                 kernel_size=31, output_dim=256, dropout=0.1, pool_type='attention'):
        super().__init__()
        self.backbone = LightweightELKBackbone(in_channels, d_model, num_layers, kernel_size, dropout)
        self.pool_type = pool_type
        head_layers = [
            nn.Linear(d_model, output_dim),
            nn.LayerNorm(output_dim),
            nn.GELU(),
            nn.Dropout(dropout),
            nn.Linear(output_dim, output_dim),
        ]
        self.projection = nn.Sequential(*head_layers)

    def forward(self, x, mask=None):
        """Encode ``x``; returns L2-normalised embeddings of width ``output_dim``.

        With ``mask`` given, pooling over time uses ``improved_masked_pooling``
        with this encoder's ``pool_type``; without it, a plain mean over the
        last axis is used.
        """
        features = self.backbone(x)
        if mask is None:
            pooled = features.mean(dim=-1)
        else:
            pooled = improved_masked_pooling(features, mask, self.pool_type)
        return F.normalize(self.projection(pooled), dim=1)

class ImprovedMTCLoss(nn.Module):
    """Consistency loss between clean and masked embeddings.

    Args:
        temperature: Divisor applied to the cosine similarity before the
            log-sigmoid (only used when ``use_cosine`` is True).
        use_cosine: If True, minimise ``-log(sigmoid(cos_sim / temperature))``
            averaged over the batch; otherwise use mean squared error.
    """

    def __init__(self, temperature=0.1, use_cosine=True):
        super().__init__()
        self.temperature = temperature
        self.use_cosine = use_cosine

    def forward(self, z_clean, z_masked):
        """Return a scalar loss for paired (B, D) embeddings."""
        if self.use_cosine:
            sim = F.cosine_similarity(z_clean, z_masked, dim=1)
            # logsigmoid is the numerically stable form of log(sigmoid(.)):
            # for strongly negative inputs sigmoid underflows to 0 and the
            # naive log returns -inf (i.e. an infinite loss).
            loss = -F.logsigmoid(sim / self.temperature).mean()
        else:
            loss = F.mse_loss(z_clean, z_masked)
        return loss

class ImprovedNTXentLoss(nn.Module):
    """Symmetric NT-Xent contrastive loss over two views of a batch.

    For each anchor the positive is the same-index sample in the other view;
    negatives are all other cross-view samples plus all other same-view
    samples (self-similarity is excluded).

    Args:
        temperature: Softmax temperature applied to cosine similarities.
        use_hard_negatives: If True, same-view similarity logits are
            multiplied by ``hard_neg_weight``.
        hard_neg_weight: Multiplier for same-view logits (default 2.0).
    """

    def __init__(self, temperature=0.07, use_hard_negatives=True, hard_neg_weight=2.0):
        super().__init__()
        self.temperature = temperature
        self.use_hard_negatives = use_hard_negatives
        self.hard_neg_weight = hard_neg_weight

    def forward(self, z1, z2):
        """Return the scalar loss for two (B, D) embedding batches."""
        B = z1.shape[0]
        device = z1.device

        z1 = F.normalize(z1, dim=1)
        z2 = F.normalize(z2, dim=1)

        sim_11 = (z1 @ z1.T) / self.temperature
        sim_22 = (z2 @ z2.T) / self.temperature
        sim_12 = (z1 @ z2.T) / self.temperature
        sim_21 = sim_12.T

        if self.use_hard_negatives:
            sim_11 = sim_11 * self.hard_neg_weight
            sim_22 = sim_22 * self.hard_neg_weight

        # Exclude self-similarity. -inf (rather than a huge finite constant) is
        # representable in every float dtype, so this also works when the
        # similarities are half precision.
        self_mask = torch.eye(B, device=device, dtype=torch.bool)
        sim_11 = sim_11.masked_fill(self_mask, float('-inf'))
        sim_22 = sim_22.masked_fill(self_mask, float('-inf'))

        # Columns [0, B) hold the cross-view logits, so the positive for row i
        # sits at column i.
        logits_1 = torch.cat([sim_12, sim_11], dim=1)
        logits_2 = torch.cat([sim_21, sim_22], dim=1)
        labels = torch.arange(B, device=device)

        loss_1 = F.cross_entropy(logits_1, labels)
        loss_2 = F.cross_entropy(logits_2, labels)

        return 0.5 * (loss_1 + loss_2)

class ImprovedSSLFramework(nn.Module):
    """Self-supervised wrapper: encoder + a corruption/augmentation + two losses.

    ``method`` selects the second view that is contrasted against the clean one:
    'sl_only' (no SSL, zero loss), 'ssl_wo_missing' (jitter/scale only),
    'random_point_drop', 'channel_drop' or 'mtc_micl' (temporal span masking).
    """

    def __init__(self, method='mtc_micl', in_channels=9, d_model=128,
                 num_layers=1, kernel_size=31, output_dim=256, dropout=0.1,
                 mask_ratio=0.15, temperature=0.07,
                 lambda_mtc=1.0, lambda_micl=1.0, channel_drop_mode='random',
                 pool_type='attention'):
        super().__init__()

        self.method = method
        self.encoder = ImprovedELKEncoder(in_channels, d_model, num_layers,
                                        kernel_size, output_dim, dropout, pool_type)

        # Lazily-evaluated constructors so only the selected augmentation is built.
        augmentation_factories = {
            'sl_only': lambda: None,
            'ssl_wo_missing': lambda: StandardAugmentation(),
            'random_point_drop': lambda: ImprovedRandomPointDrop(drop_ratio=mask_ratio),
            'channel_drop': lambda: ImprovedChannelDrop(drop_prob=mask_ratio, mode=channel_drop_mode),
            'mtc_micl': lambda: ImprovedTemporalMasking(mask_ratio=mask_ratio),
        }
        if method not in augmentation_factories:
            raise ValueError(f"Unknown method: {method}")
        self.augmentation = augmentation_factories[method]()

        self.mtc_loss = ImprovedMTCLoss(temperature=0.1)
        self.micl_loss = ImprovedNTXentLoss(temperature=temperature)
        self.lambda_mtc = lambda_mtc
        self.lambda_micl = lambda_micl

    def forward(self, x):
        """Compute the SSL loss for a batch.

        Returns:
            ``(loss, losses_dict)`` where ``losses_dict`` has float entries
            'total', 'mtc' and 'micl'. For 'sl_only' the loss is a constant
            zero tensor; for 'ssl_wo_missing' only the contrastive term is used.
        """
        if self.method == 'sl_only':
            return torch.tensor(0.0, device=x.device), {'total': 0.0, 'mtc': 0.0, 'micl': 0.0}

        # Clean view is pooled without a mask; the corrupted view uses the
        # mask returned by the augmentation.
        z_clean = self.encoder(x, mask=None)
        x_aug, mask = self.augmentation(x)
        z_aug = self.encoder(x_aug, mask=mask)

        contrastive_only = self.method == 'ssl_wo_missing'
        if contrastive_only:
            loss_mtc = torch.tensor(0.0, device=x.device)
        else:
            loss_mtc = self.mtc_loss(z_clean, z_aug)
        loss_micl = self.micl_loss(z_clean, z_aug)

        if contrastive_only:
            loss = loss_micl
        else:
            loss = self.lambda_mtc * loss_mtc + self.lambda_micl * loss_micl

        return loss, {
            'total': loss.item(),
            'mtc': loss_mtc.item(),
            'micl': loss_micl.item()
        }

    def get_representation(self, x, mask=None):
        """Encode ``x`` without tracking gradients (module train/eval mode is left as is)."""
        with torch.no_grad():
            return self.encoder(x, mask=mask)

class TwoStageLinearClassifier(nn.Module):
    """Classification head on top of an encoder whose weights start frozen.

    Args:
        encoder: Module called as ``encoder(x, mask=None)`` that returns
            (B, feature_dim) embeddings.
        num_classes: Number of output logits.
        feature_dim: Width of the encoder output fed to the head. Defaults to
            256, the value that was previously hard-coded.
    """

    def __init__(self, encoder, num_classes=6, feature_dim=256):
        super().__init__()
        self.encoder = encoder
        self.classifier = nn.Sequential(
            nn.Linear(feature_dim, 128),
            nn.ReLU(),
            nn.Dropout(0.1),
            nn.Linear(128, num_classes)
        )

        # Start with the encoder frozen; callers unfreeze it for fine-tuning.
        self._freeze_encoder()

    def _set_encoder_trainable(self, trainable):
        """Toggle ``requires_grad`` on every encoder parameter."""
        for param in self.encoder.parameters():
            param.requires_grad = trainable

    def _freeze_encoder(self):
        """Stop gradients for all encoder parameters (does not change train/eval mode)."""
        self._set_encoder_trainable(False)

    def _unfreeze_encoder(self):
        """Re-enable gradients for all encoder parameters."""
        self._set_encoder_trainable(True)

    def forward(self, x):
        """Return class logits for ``x``; the encoder is always called without a mask."""
        z = self.encoder(x, mask=None)
        return self.classifier(z)

def count_parameters(model):
    """Return the total number of scalar parameters in ``model``."""
    n_params = 0
    for tensor in model.parameters():
        n_params += tensor.numel()
    return n_params

def count_trainable_parameters(model):
    """Return the number of scalar parameters that have ``requires_grad`` set."""
    n_trainable = 0
    for tensor in model.parameters():
        if tensor.requires_grad:
            n_trainable += tensor.numel()
    return n_trainable

def calculate_flops(model, input_shape=(1, 9, 128), device='cuda'):
    """Count multiply-accumulates of all Conv1d and Linear layers in one forward pass.

    The model is switched to eval mode and fed a random tensor of
    ``input_shape`` created on ``device`` (the model itself is not moved).
    Other layer types contribute nothing to the count.

    Args:
        model: Module to profile.
        input_shape: Shape of the random input.
        device: Device on which the input is created.

    Returns:
        Integer count: for Conv1d, output elements x (in_channels / groups) x
        kernel size; for Linear, output elements x in_features.
    """
    model.eval()
    x = torch.randn(input_shape).to(device)

    total_flops = 0

    def flop_hook(module, inputs, output):
        nonlocal total_flops
        if isinstance(module, nn.Conv1d):
            # weight is (out_channels, in_channels // groups, kernel_size).
            # Using the real output size (instead of recomputing the length)
            # stays correct for dilation and string padding such as 'same'.
            _, in_per_group, kernel_size = module.weight.shape
            total_flops += output.numel() * in_per_group * kernel_size
        elif isinstance(module, nn.Linear):
            # output.numel() covers the batch and any extra leading dims.
            total_flops += output.numel() * module.in_features

    hooks = [
        module.register_forward_hook(flop_hook)
        for module in model.modules()
        if isinstance(module, (nn.Conv1d, nn.Linear))
    ]

    try:
        with torch.no_grad():
            _ = model(x)
    finally:
        # Always detach the hooks, even if the forward pass raised.
        for hook in hooks:
            hook.remove()

    return total_flops

def measure_inference_time(model, dataloader, device, num_batches=10):
    """Measure average forward-pass latency in milliseconds per sample.

    Args:
        model: Module to time; it is switched to eval mode.
        dataloader: Iterable of ``(x, y)`` batches; only ``x`` is used.
        device: ``torch.device`` or device string the inputs are moved to.
        num_batches: Maximum number of batches to time.

    Returns:
        Total forward time divided by the number of samples seen, in ms.
        Host-to-device transfer is excluded; no warm-up pass is performed.

    Raises:
        ValueError: If no samples were timed (empty loader or ``num_batches <= 0``).
    """
    device = torch.device(device)  # accept both 'cuda' strings and torch.device
    sync_cuda = device.type == 'cuda'

    model.eval()
    total_time = 0.0
    total_samples = 0

    with torch.no_grad():
        for i, (x, _) in enumerate(dataloader):
            if i >= num_batches:
                break
            x = x.to(device)

            # CUDA kernels run asynchronously; synchronise so the timer brackets
            # the actual computation.
            if sync_cuda:
                torch.cuda.synchronize()
            start_time = time.perf_counter()
            _ = model(x)
            if sync_cuda:
                torch.cuda.synchronize()
            end_time = time.perf_counter()

            total_time += end_time - start_time
            total_samples += x.size(0)

    if total_samples == 0:
        raise ValueError("measure_inference_time: no samples were timed")

    return total_time / total_samples * 1000

def train_ssl_epoch(model, dataloader, optimizer, device):
    """Run one self-supervised training epoch.

    ``model(x)`` must return ``(loss, losses_dict)`` with float entries
    'total', 'mtc' and 'micl'. The optimizer step (with gradient-norm clipping
    at 1.0) is taken only when the loss value is strictly positive.

    Returns:
        Dict with per-batch averages under 'loss', 'mtc' and 'micl'.
    """
    model.train()
    running = {'total': 0, 'mtc': 0, 'micl': 0}

    for batch, _ in dataloader:
        batch = batch.to(device)
        optimizer.zero_grad()
        loss, parts = model(batch)

        # Skip the update for a zero loss (e.g. the constant returned when
        # there is nothing to optimise).
        should_step = loss.item() > 0
        if should_step:
            loss.backward()
            torch.nn.utils.clip_grad_norm_(model.parameters(), max_norm=1.0)
            optimizer.step()

        for key in ('total', 'mtc', 'micl'):
            running[key] += parts[key]

    n_batches = len(dataloader)
    return {
        'loss': running['total'] / n_batches,
        'mtc': running['mtc'] / n_batches,
        'micl': running['micl'] / n_batches
    }

def _run_supervised_epoch(model, train_loader, criterion, optimizer, device, encoder_frozen):
    """Train for one epoch with gradient clipping; return training accuracy in percent."""
    model.train()
    if encoder_frozen:
        # requires_grad=False alone does not freeze BatchNorm running statistics
        # or disable dropout; eval mode keeps the frozen encoder's features fixed.
        model.encoder.eval()

    correct = 0
    total = 0
    for x, y in train_loader:
        x, y = x.to(device), y.to(device)
        optimizer.zero_grad()
        logits = model(x)
        loss = criterion(logits, y)
        loss.backward()
        torch.nn.utils.clip_grad_norm_(model.parameters(), max_norm=1.0)
        optimizer.step()

        pred = logits.argmax(dim=1)
        correct += (pred == y).sum().item()
        total += y.size(0)

    return 100.0 * correct / total


def train_two_stage_linear(model, train_loader, test_loader, device,
                          stage1_epochs=75, stage2_epochs=25,
                          stage1_lr=1e-3, stage2_lr=1e-5):
    """Train a classifier head on a frozen encoder, then fine-tune everything.

    Stage 1 optimises only ``model.classifier`` with the encoder frozen and in
    eval mode. Stage 2 unfreezes the encoder and trains it at ``stage2_lr``
    while the classifier keeps ``stage1_lr``. Both stages use AdamW with
    cosine annealing.

    Args:
        model: Module exposing ``encoder``, ``classifier``,
            ``_freeze_encoder()`` and ``_unfreeze_encoder()``.
        train_loader: Iterable of ``(x, y)`` training batches.
        test_loader: Loader passed to ``evaluate_linear`` after every epoch.
        device: Device batches are moved to.
        stage1_epochs, stage2_epochs: Epoch counts of the two stages.
        stage1_lr, stage2_lr: Learning rates for classifier / encoder.

    Returns:
        The highest per-epoch test accuracy (percent) seen across both stages.
        NOTE(review): this is a maximum taken on ``test_loader`` itself, so it
        is an optimistic estimate rather than a held-out score.
    """
    criterion = nn.CrossEntropyLoss()
    best_test_acc = 0.0

    print(f"  Stage 1: Frozen encoder + classifier training ({stage1_epochs} epochs)")
    model._freeze_encoder()
    optimizer = torch.optim.AdamW(model.classifier.parameters(), lr=stage1_lr, weight_decay=1e-4)
    scheduler = torch.optim.lr_scheduler.CosineAnnealingLR(optimizer, T_max=stage1_epochs)

    for epoch in range(stage1_epochs):
        train_acc = _run_supervised_epoch(
            model, train_loader, criterion, optimizer, device, encoder_frozen=True
        )
        scheduler.step()
        test_acc = evaluate_linear(model, test_loader, device)
        best_test_acc = max(best_test_acc, test_acc)

        if (epoch + 1) % 15 == 0:
            print(f"    Epoch {epoch+1}/{stage1_epochs}: Train Acc={train_acc:.2f}%, Test Acc={test_acc:.2f}%")

    print(f"  Stage 2: Fine-tuning entire model ({stage2_epochs} epochs)")
    model._unfreeze_encoder()

    # Separate learning rates: small for the pretrained encoder, larger for the head.
    param_groups = [
        {'params': model.encoder.parameters(), 'lr': stage2_lr},
        {'params': model.classifier.parameters(), 'lr': stage1_lr}
    ]
    optimizer = torch.optim.AdamW(param_groups, weight_decay=1e-4)
    scheduler = torch.optim.lr_scheduler.CosineAnnealingLR(optimizer, T_max=stage2_epochs)

    for epoch in range(stage2_epochs):
        train_acc = _run_supervised_epoch(
            model, train_loader, criterion, optimizer, device, encoder_frozen=False
        )
        scheduler.step()
        test_acc = evaluate_linear(model, test_loader, device)
        best_test_acc = max(best_test_acc, test_acc)

        if (epoch + 1) % 5 == 0:
            print(f"    Epoch {epoch+1}/{stage2_epochs}: Train Acc={train_acc:.2f}%, Test Acc={test_acc:.2f}%")

    return best_test_acc

def evaluate_linear(model, dataloader, device):
    """Return top-1 accuracy (percent) of ``model`` over ``dataloader``.

    The model is put in eval mode and run without gradient tracking.
    """
    model.eval()
    hits, seen = 0, 0

    with torch.no_grad():
        for inputs, targets in dataloader:
            inputs = inputs.to(device)
            targets = targets.to(device)
            predictions = model(inputs).argmax(dim=1)
            hits += (predictions == targets).sum().item()
            seen += targets.size(0)

    return 100.0 * hits / seen

def evaluate_missing_robustness_improved(model, dataloader, device, method_name, missing_ratios=[0.0, 0.15, 0.3, 0.5]):
    """Measure accuracy when inputs are corrupted at several missing ratios.

    The corruption is chosen from ``method_name``: point drop for
    'random_point_drop', channel drop for names containing 'channel_drop'
    ('sensor' mode if the name also contains 'sensor'), and temporal span
    masking otherwise. For 'ssl_wo_missing' every ratio above 0 is recorded as
    0.0 without being evaluated.

    Args:
        model: Classifier called as ``model(x)``; switched to eval mode.
        dataloader: Iterable of ``(x, y)`` batches.
        device: Device batches are moved to.
        method_name: Method key (e.g. 'channel_drop') or display name
            (e.g. 'Channel Drop (Sensor)'); it is normalised before matching.
        missing_ratios: Ratios to evaluate; 0 means uncorrupted input. The
            default list is only iterated, never modified.

    Returns:
        Dict mapping each ratio to accuracy in percent.
    """
    # Normalise display names to key form so that e.g. 'SSL w/o Missing' and
    # 'ssl_wo_missing' select the same branch. Without this, display names
    # match none of the keys and everything falls back to temporal masking.
    method_key = method_name.lower().replace('/', '').replace(' ', '_')

    model.eval()
    results = {}

    for ratio in missing_ratios:
        if method_key == 'ssl_wo_missing' and ratio > 0:
            # Placeholder, not a measurement: this method is not evaluated
            # under missing data.
            results[ratio] = 0.0
            continue

        masking = None
        if ratio > 0:
            if method_key == 'random_point_drop':
                masking = ImprovedRandomPointDrop(drop_ratio=ratio)
            elif 'channel_drop' in method_key:
                mode = 'sensor' if 'sensor' in method_key else 'random'
                masking = ImprovedChannelDrop(drop_prob=ratio, mode=mode)
            else:
                masking = ImprovedTemporalMasking(mask_ratio=ratio)

        correct = 0
        total = 0

        with torch.no_grad():
            for x, y in dataloader:
                x, y = x.to(device), y.to(device)

                if masking is not None:
                    # The returned mask is not used; the classifier only sees
                    # the corrupted input.
                    x_masked, _ = masking(x)
                    logits = model(x_masked)
                else:
                    logits = model(x)

                pred = logits.argmax(dim=1)
                correct += (pred == y).sum().item()
                total += y.size(0)

        results[ratio] = 100.0 * correct / total

    return results

def run_improved_benchmark(data_dir, device, ssl_epochs=150, batch_size=128):
    """Run the full benchmark over all configured methods and save the results.

    For every method this (1) optionally pretrains the encoder with the
    method's SSL objective, (2) trains a classifier (fully supervised for
    'sl_only', two-stage otherwise), (3) records parameter count, FLOPs and
    inference time, and (4) evaluates accuracy under increasing missing
    ratios. A summary table is printed and the results are written to
    ``improved_benchmark_results.json`` in the working directory.

    Args:
        data_dir: Directory passed to ``load_uci_har_raw``.
        device: ``torch.device`` used for all models and batches.
        ssl_epochs: Number of SSL pretraining epochs per method.
        batch_size: Batch size for both loaders.

    Returns:
        Mapping from method display name to its metrics dict.
    """
    print("="*80)
    print("IMPROVED BENCHMARK: Lightweight ELK with Two-Stage Training")
    print("="*80)

    X_train, y_train, X_test, y_test, activity_names = load_uci_har_raw(data_dir)

    train_dataset = TensorDataset(torch.FloatTensor(X_train), torch.LongTensor(y_train))
    test_dataset = TensorDataset(torch.FloatTensor(X_test), torch.LongTensor(y_test))
    train_loader = DataLoader(train_dataset, batch_size=batch_size, shuffle=True, num_workers=4)
    test_loader = DataLoader(test_dataset, batch_size=batch_size, shuffle=False, num_workers=4)

    # 'name' is the display label; 'method' is the key understood by
    # ImprovedSSLFramework and the robustness evaluation.
    methods = [
        {'name': 'SL-Only', 'method': 'sl_only', 'desc': 'Supervised learning only'},
        {'name': 'SSL w/o Missing', 'method': 'ssl_wo_missing', 'desc': 'Standard contrastive'},
        {'name': 'Random Point Drop', 'method': 'random_point_drop', 'desc': 'Non-contiguous missing'},
        {'name': 'Channel Drop (Random)', 'method': 'channel_drop', 'desc': 'Random channel missing', 'channel_mode': 'random'},
        {'name': 'Channel Drop (Sensor)', 'method': 'channel_drop', 'desc': 'Sensor-wise missing', 'channel_mode': 'sensor'},
        {'name': 'MTC + MICL (Ours)', 'method': 'mtc_micl', 'desc': 'Improved temporal masking'},
    ]

    results = defaultdict(dict)

    for config in methods:
        print(f"\n{'='*80}")
        print(f"Method: {config['name']}")
        print(f"Description: {config['desc']}")
        print(f"{'='*80}")

        if config['method'] != 'sl_only':
            print(f"\n[Phase 1] SSL Pretraining ({ssl_epochs} epochs)...")

            ssl_model = ImprovedSSLFramework(
                method=config['method'],
                channel_drop_mode=config.get('channel_mode', 'random'),
                mask_ratio=0.15,
                num_layers=1
            ).to(device)

            optimizer = torch.optim.AdamW(ssl_model.parameters(), lr=5e-4, weight_decay=1e-4)
            scheduler = torch.optim.lr_scheduler.CosineAnnealingLR(optimizer, T_max=ssl_epochs)

            for epoch in range(ssl_epochs):
                metrics = train_ssl_epoch(ssl_model, train_loader, optimizer, device)
                scheduler.step()

                if (epoch + 1) % 25 == 0:
                    print(f"  Epoch {epoch+1}/{ssl_epochs}: Loss={metrics['loss']:.4f}, "
                          f"MTC={metrics['mtc']:.4f}, MICL={metrics['micl']:.4f}")

            encoder = ssl_model.encoder
        else:
            encoder = ImprovedELKEncoder(num_layers=1).to(device)

        if config['method'] == 'sl_only':
            print(f"\n[Phase 2] Supervised Learning (100 epochs)...")
            linear_model = TwoStageLinearClassifier(encoder, num_classes=6).to(device)
            # The wrapper freezes the encoder on construction; the supervised
            # baseline trains everything end to end.
            linear_model._unfreeze_encoder()

            criterion = nn.CrossEntropyLoss()
            optimizer = torch.optim.AdamW(linear_model.parameters(), lr=1e-3, weight_decay=1e-4)
            scheduler = torch.optim.lr_scheduler.CosineAnnealingLR(optimizer, T_max=100)

            best_test_acc = 0.0
            for epoch in range(100):
                linear_model.train()
                total_loss = 0
                correct = 0
                total = 0

                for x, y in train_loader:
                    x, y = x.to(device), y.to(device)
                    optimizer.zero_grad()
                    logits = linear_model(x)
                    loss = criterion(logits, y)
                    loss.backward()
                    torch.nn.utils.clip_grad_norm_(linear_model.parameters(), max_norm=1.0)
                    optimizer.step()

                    total_loss += loss.item()
                    pred = logits.argmax(dim=1)
                    correct += (pred == y).sum().item()
                    total += y.size(0)

                scheduler.step()
                train_acc = 100.0 * correct / total
                test_acc = evaluate_linear(linear_model, test_loader, device)
                # NOTE(review): the best epoch is selected on the test loader,
                # so this figure is optimistic.
                best_test_acc = max(best_test_acc, test_acc)

                if (epoch + 1) % 20 == 0:
                    print(f"  Epoch {epoch+1}/100: Train Acc={train_acc:.2f}%, Test Acc={test_acc:.2f}%")
        else:
            print(f"\n[Phase 2] Two-Stage Linear Evaluation...")
            linear_model = TwoStageLinearClassifier(encoder, num_classes=6).to(device)

            best_test_acc = train_two_stage_linear(
                linear_model, train_loader, test_loader, device,
                stage1_epochs=75, stage2_epochs=25,
                stage1_lr=1e-3, stage2_lr=1e-5
            )

        results[config['name']]['test_accuracy'] = best_test_acc

        total_params = count_parameters(linear_model)
        encoder_params = count_parameters(linear_model.encoder)
        total_flops = calculate_flops(linear_model, input_shape=(1, 9, 128), device=device)
        inference_time = measure_inference_time(linear_model, test_loader, device)

        results[config['name']]['total_params'] = total_params
        results[config['name']]['encoder_params'] = encoder_params
        results[config['name']]['total_flops'] = total_flops
        results[config['name']]['inference_time_ms'] = inference_time

        print(f"\n[Phase 3] Missing Robustness Test...")
        # The robustness evaluation dispatches on method keys (e.g.
        # 'random_point_drop', 'channel_drop...sensor'), not on display names,
        # so build the key from the config rather than passing config['name'].
        robustness_key = config['method']
        if config.get('channel_mode') == 'sensor':
            robustness_key += '_sensor'
        missing_results = evaluate_missing_robustness_improved(
            linear_model, test_loader, device, robustness_key,
            missing_ratios=[0.0, 0.15, 0.3, 0.5]
        )
        results[config['name']]['missing_robustness'] = missing_results

        for ratio, acc in missing_results.items():
            print(f"  Missing Ratio {ratio:.2f}: {acc:.2f}%")

        print(f"\n[Model Statistics]")
        print(f"  Total Parameters: {total_params:,}")
        print(f"  Encoder Parameters: {encoder_params:,}")
        print(f"  Total FLOPs: {total_flops/1e6:.2f}M")
        print(f"  Inference Time: {inference_time:.2f} ms/sample")

    print(f"\n{'='*80}")
    print("IMPROVED BENCHMARK SUMMARY")
    print(f"{'='*80}")
    print(f"{'Method':<25} {'Clean Acc':<10} {'15% Miss':<10} {'30% Miss':<10} {'50% Miss':<10} {'Params':<8} {'FLOPs':<8} {'Time(ms)':<8}")
    print("-"*105)

    for method_name, result in results.items():
        clean_acc = result['missing_robustness'][0.0]
        miss_15 = result['missing_robustness'].get(0.15, 0.0)
        miss_30 = result['missing_robustness'].get(0.3, 0.0)
        miss_50 = result['missing_robustness'].get(0.5, 0.0)
        params = result['total_params'] // 1000
        flops = result['total_flops'] / 1e6
        inf_time = result['inference_time_ms']

        print(f"{method_name:<25} {clean_acc:>8.1f}% {miss_15:>8.1f}% "
              f"{miss_30:>8.1f}% {miss_50:>8.1f}% {params:>6}K {flops:>6.1f}M {inf_time:>6.1f}")

    # json.dump stringifies the float ratio keys of 'missing_robustness'.
    with open('improved_benchmark_results.json', 'w') as f:
        json.dump(results, f, indent=2)
    print(f"\n✓ Results saved to: improved_benchmark_results.json")

    return results

def main():
    """Pick a device, print environment info and run the benchmark on '/content/'."""
    data_dir = '/content/'
    use_gpu = torch.cuda.is_available()
    device = torch.device('cuda' if use_gpu else 'cpu')

    for banner in (f"Device: {device}", f"PyTorch version: {torch.__version__}\n"):
        print(banner)

    return run_improved_benchmark(
        data_dir=data_dir,
        device=device,
        ssl_epochs=150,
        batch_size=128,
    )

# Run the benchmark only when this code is executed directly, not when imported.
if __name__ == '__main__':
    main()

Device: cuda
PyTorch version: 2.8.0+cu126

IMPROVED BENCHMARK: Lightweight ELK with Two-Stage Training

Method: SL-Only
Description: Supervised learning only

[Phase 2] Supervised Learning (100 epochs)...




  Epoch 20/100: Train Acc=96.18%, Test Acc=91.52%
  Epoch 40/100: Train Acc=96.68%, Test Acc=92.47%
  Epoch 60/100: Train Acc=97.05%, Test Acc=92.77%
  Epoch 80/100: Train Acc=97.59%, Test Acc=92.57%
  Epoch 100/100: Train Acc=97.95%, Test Acc=93.04%

[Phase 3] Missing Robustness Test...
  Missing Ratio 0.00: 93.04%
  Missing Ratio 0.15: 90.09%
  Missing Ratio 0.30: 83.54%
  Missing Ratio 0.50: 70.55%

[Model Statistics]
  Total Parameters: 163,334
  Encoder Parameters: 129,664
  Total FLOPs: 3.79M
  Inference Time: 0.02 ms/sample

Method: SSL w/o Missing
Description: Standard contrastive

[Phase 1] SSL Pretraining (150 epochs)...
  Epoch 25/150: Loss=1.9001, MTC=0.0000, MICL=1.9001
  Epoch 50/150: Loss=1.4677, MTC=0.0000, MICL=1.4677
  Epoch 75/150: Loss=1.2418, MTC=0.0000, MICL=1.2418
  Epoch 100/150: Loss=1.0426, MTC=0.0000, MICL=1.0426
  Epoch 125/150: Loss=0.9185, MTC=0.0000, MICL=0.9185
  Epoch 150/150: Loss=0.9490, MTC=0.0000, MICL=0.9490

[Phase 2] Two-Stage Linear Evaluation..