In [20]:
# ============================================================
# UCI-HAR (Inertial Signals, raw 9ch) Training & Evaluation
# InceptionMK + Rotation SSL + (optional) Rotation-Consistency + RoCo-TTA
# + Proto-Orbit Contrast (class-prototype-based rotation-invariant loss)
# FIXED VERSION - Resolving gradient flow issues
# ============================================================

import os
import io
import math
import zipfile
import urllib.request
from pathlib import Path
from typing import Tuple

import numpy as np
from sklearn.metrics import accuracy_score, f1_score
from sklearn.preprocessing import StandardScaler

import torch
import torch.nn as nn
import torch.nn.functional as F
from torch.utils.data import Dataset, DataLoader


# ----------------------------
# 0) Utilities: download & load UCI-HAR (raw inertial)
# ----------------------------
UCI_URL = "https://archive.ics.uci.edu/static/public/240/human+activity+recognition+using+smartphones.zip"

def download_and_extract_uci(root: str):
    root = Path(root)
    data_dir = root / "UCI_HAR_Dataset"
    if data_dir.exists():
        return data_dir

    root.mkdir(parents=True, exist_ok=True)
    zip_path = root / "uci_har.zip"
    print("[UCI] Downloading dataset...")
    urllib.request.urlretrieve(UCI_URL, zip_path)

    print("[UCI] Extracting...")
    with zipfile.ZipFile(zip_path, 'r') as zf:
        zf.extractall(root)

    # The archive extracts a folder "UCI HAR Dataset"
    # unify name to UCI_HAR_Dataset
    original = root / "UCI HAR Dataset"
    if original.exists():
        original.rename(data_dir)
    zip_path.unlink(missing_ok=True)
    print("[UCI] Ready at:", str(data_dir))
    return data_dir

def _load_signal_file(path: Path) -> np.ndarray:
    # Each file: N lines, each line has 128 floats separated by space
    arr = np.loadtxt(path, dtype=np.float32)
    # shape: (N, 128)
    return arr

def load_uci_inertial(data_dir: Path):
    """
    Returns:
      X_train: (N_train, 128, 9), y_train: (N_train,)
      X_test:  (N_test, 128, 9), y_test:  (N_test,)
      label_names: list of 6 activities
    """
    base = data_dir

    # 9 inertial signals
    train_inertial = base / "train" / "Inertial Signals"
    test_inertial  = base / "test" / "Inertial Signals"

    files = [
        "total_acc_x", "total_acc_y", "total_acc_z",
        "body_acc_x",  "body_acc_y",  "body_acc_z",
        "body_gyro_x", "body_gyro_y", "body_gyro_z"
    ]
    # Build arrays by stacking channel-wise
    Xtr_list = []
    Xte_list = []
    for name in files:
        ftr = train_inertial / f"{name}_train.txt"
        fte = test_inertial  / f"{name}_test.txt"
        Xtr_list.append(_load_signal_file(ftr))  # (Ntr,128)
        Xte_list.append(_load_signal_file(fte))  # (Nte,128)

    # Stack -> (C, N, 128) then transpose
    X_train = np.stack(Xtr_list, axis=2)  # (Ntr,128,9)
    X_test  = np.stack(Xte_list, axis=2)  # (Nte,128,9)

    # Labels: 1..6
    y_train = np.loadtxt(base / "train" / "y_train.txt", dtype=np.int64).ravel() - 1
    y_test  = np.loadtxt(base / "test" / "y_test.txt", dtype=np.int64).ravel() - 1

    # Label names
    label_names = ["WALKING","WALKING_UPSTAIRS","WALKING_DOWNSTAIRS","SITTING","STANDING","LAYING"]

    return X_train, y_train, X_test, y_test, label_names


# ----------------------------
# 1) Rotation utilities (3-axis group-wise)
# ----------------------------
def _batched_rotation_matrix(angles):  # angles: (B, 3) in radians: rx, ry, rz
    B = angles.shape[0]
    rx, ry, rz = angles[:, 0], angles[:, 1], angles[:, 2]

    cx, sx = torch.cos(rx), torch.sin(rx)
    cy, sy = torch.cos(ry), torch.sin(ry)
    cz, sz = torch.cos(rz), torch.sin(rz)

    Rx = torch.zeros(B, 3, 3, device=angles.device, dtype=angles.dtype)
    Rx[:, 0, 0] = 1.
    Rx[:, 1, 1] = cx; Rx[:, 1, 2] = -sx
    Rx[:, 2, 1] = sx; Rx[:, 2, 2] = cx

    Ry = torch.zeros(B, 3, 3, device=angles.device, dtype=angles.dtype)
    Ry[:, 0, 0] = cy;  Ry[:, 0, 2] = sy
    Ry[:, 1, 1] = 1.
    Ry[:, 2, 0] = -sy; Ry[:, 2, 2] = cy

    Rz = torch.zeros(B, 3, 3, device=angles.device, dtype=angles.dtype)
    Rz[:, 0, 0] = cz; Rz[:, 0, 1] = -sz
    Rz[:, 1, 0] = sz; Rz[:, 1, 1] = cz
    Rz[:, 2, 2] = 1.

    R = Rz @ Ry @ Rx
    return R  # (B, 3, 3)

def apply_random_rotation_xyz(x, max_deg=20.0, group_size=3):
    """
    x: (B, T, C)
    C must be multiple of 3. UCI inertial: 9ch -> OK.
    """
    B, T, C = x.shape
    assert C % group_size == 0
    G = C // group_size

    max_rad = max_deg * math.pi / 180.0
    angles = (torch.rand(B, 3, device=x.device, dtype=x.dtype) * 2 - 1) * max_rad
    R = _batched_rotation_matrix(angles)  # (B,3,3)

    x_rot = x.clone()
    for g in range(G):
        seg = x[:, :, g*group_size:(g+1)*group_size]      # (B, T, 3)
        seg2 = seg.reshape(B*T, 3)
        Rb = R.repeat_interleave(T, dim=0)                # (B*T,3,3)
        seg_rot = (Rb @ seg2.unsqueeze(-1)).squeeze(-1)   # (B*T,3)
        x_rot[:, :, g*group_size:(g+1)*group_size] = seg_rot.view(B, T, 3)
    return x_rot

def make_rotation_class(x, num_rotations=4):
    """Fixed version with proper tensor handling"""
    device = x.device
    B, T, C = x.shape
    angles_deg = torch.tensor([0., 90., 180., 270.], device=device, dtype=x.dtype)
    labels = torch.randint(0, len(angles_deg), (B,), device=device, dtype=torch.long)
    rad = angles_deg[labels] * math.pi / 180.0

    angles = torch.stack([torch.zeros_like(rad), torch.zeros_like(rad), rad], dim=1)  # (B,3)
    R = _batched_rotation_matrix(angles)  # (B,3,3)

    x_rot = x.clone()
    G = C // 3
    for g in range(G):
        seg = x[:, :, g*3:(g+1)*3]           # (B,T,3)
        seg2 = seg.reshape(B*T, 3)
        Rb = R.repeat_interleave(T, dim=0)   # (B*T,3,3)
        seg_rot = (Rb @ seg2.unsqueeze(-1)).squeeze(-1)
        x_rot[:, :, g*3:(g+1)*3] = seg_rot.view(B, T, 3)

    return x_rot, labels


# ----------------------------
# 2) Model (InceptionMK + adapter + TTA + Prototypes)
# ----------------------------
class DSConv1D(nn.Module):
    def __init__(self, in_channels, out_channels, kernel_size):
        super().__init__()
        self.depthwise = nn.Conv1d(in_channels, in_channels, kernel_size=kernel_size,
                                   padding=kernel_size // 2, groups=in_channels, bias=False)
        self.pointwise = nn.Conv1d(in_channels, out_channels, kernel_size=1, bias=False)
        self.bn = nn.BatchNorm1d(out_channels)
        self.act = nn.ReLU(inplace=True)
    def forward(self, x):
        x = self.depthwise(x)
        x = self.pointwise(x)
        x = self.bn(x)
        x = self.act(x)
        return x

class InceptionBlock(nn.Module):
    def __init__(self, in_channels, out_channels):
        super().__init__()
        self.branch1 = nn.Sequential(
            nn.MaxPool1d(kernel_size=3, stride=1, padding=1),
            DSConv1D(in_channels, out_channels, kernel_size=1)
        )
        self.branch2 = DSConv1D(in_channels, out_channels, kernel_size=1)
        self.branch3 = DSConv1D(in_channels, out_channels, kernel_size=3)
        self.branch4 = DSConv1D(in_channels, out_channels, kernel_size=5)
        self.relu = nn.ReLU(inplace=True)
    def forward(self, x):
        out = torch.cat([
            self.branch1(x), self.branch2(x),
            self.branch3(x), self.branch4(x)
        ], dim=1)
        return self.relu(out)

class MultiKernelBlock(nn.Module):
    def __init__(self, in_channels, out_channels):
        super().__init__()
        self.branch1 = DSConv1D(in_channels, out_channels, kernel_size=1)
        self.branch2 = DSConv1D(in_channels, out_channels, kernel_size=3)
        self.branch3 = DSConv1D(in_channels, out_channels, kernel_size=5)
        self.branch4 = DSConv1D(in_channels, out_channels, kernel_size=7)
        self.relu = nn.ReLU(inplace=True)
    def forward(self, x):
        out = torch.cat([
            self.branch1(x), self.branch2(x),
            self.branch3(x), self.branch4(x)
        ], dim=1)
        return self.relu(out)

class FeatureAdapter(nn.Module):
    def __init__(self, dim):
        super().__init__()
        self.gamma = nn.Parameter(torch.ones(1, dim))
        self.beta  = nn.Parameter(torch.zeros(1, dim))
    def forward(self, z):
        return z * self.gamma + self.beta

class InceptionMK(nn.Module):
    def __init__(self, input_channels=9, stem_out=64, block_out=32,
                 embedding_dim=128, num_classes=6, num_rotations=4):
        super().__init__()
        self.stem = nn.Sequential(
            nn.Conv1d(input_channels, stem_out, kernel_size=3, padding=1, bias=False),
            nn.BatchNorm1d(stem_out),
            nn.ReLU(inplace=True)
        )
        self.inception = InceptionBlock(stem_out, block_out)
        self.mk_block  = MultiKernelBlock(block_out * 4, block_out)

        self.gap = nn.AdaptiveAvgPool1d(1)
        self.fc  = nn.Linear(block_out * 4, embedding_dim)
        self.adapter = FeatureAdapter(embedding_dim)

        self.head_act = nn.Sequential(
            nn.LayerNorm(embedding_dim),
            nn.Linear(embedding_dim, 128),
            nn.Dropout(0.3),
            nn.Linear(128, num_classes)
        )
        self.head_rot = nn.Sequential(
            nn.LayerNorm(embedding_dim),
            nn.Linear(embedding_dim, 128),
            nn.Dropout(0.3),
            nn.Linear(128, num_rotations)
        )

        # Proto-Orbit: class prototypes (EMA)
        self.num_classes = num_classes
        self.embedding_dim = embedding_dim
        self.register_buffer("prototypes", torch.zeros(num_classes, embedding_dim))
        self.register_buffer("proto_counts", torch.zeros(num_classes))
        self.proto_m = 0.99

    def forward(self, x):  # x: (B,T,C)
        x = x.transpose(1, 2)           # -> (B,C,T)
        x = self.stem(x)
        x = self.inception(x)
        x = self.mk_block(x)
        x = self.gap(x).squeeze(-1)     # (B, 4*block_out)
        z = self.fc(x)                  # (B, D)
        z = self.adapter(z)
        logits_act = self.head_act(z)
        logits_rot = self.head_rot(z)
        return logits_act, logits_rot, z

    @torch.no_grad()
    def update_prototypes(self, z, y):
        """Fixed prototype update with proper initialization"""
        if z.size(0) == 0:
            return

        z_n = F.normalize(z, dim=-1)
        unique_classes = y.unique()

        for c in unique_classes:
            mask = (y == c)
            if mask.any():
                z_class = z_n[mask].mean(dim=0)
                count = self.proto_counts[c]

                if count == 0:
                    # First update: direct assignment
                    self.prototypes[c] = z_class
                    self.proto_counts[c] = 1
                else:
                    # EMA update
                    self.prototypes[c] = self.proto_m * self.prototypes[c] + (1 - self.proto_m) * z_class
                    self.proto_counts[c] += 1


# ----------------------------
# 2.5) Rotation-consistency (optional)
# ----------------------------
def rotation_consistency_loss(model, x, K=2, max_deg=15.0, temperature=1.0):
    """Fixed rotation consistency loss"""
    with torch.no_grad():
        base_logits, _, _ = model.forward(x)
        base_logits = base_logits / temperature
        p_base = F.softmax(base_logits, dim=-1)

    kl_losses = []
    for _ in range(K):
        xr = apply_random_rotation_xyz(x, max_deg=max_deg)
        logits_r, _, _ = model.forward(xr)
        p_r = F.log_softmax(logits_r / temperature, dim=-1)
        kl = F.kl_div(p_r, p_base, reduction='batchmean', log_target=False)
        kl_losses.append(kl)

    return sum(kl_losses) / len(kl_losses)


@torch.no_grad()
def _entropy(p):
    return -(p * (p.clamp_min(1e-8).log())).sum(dim=-1)

@torch.no_grad()
def roco_tta_predict(model, x, steps=3, K=4, lr=5e-3, max_deg=20.0, temperature=1.0, weight_cons=1.0):
    was_training = model.training
    model.eval()

    # Only adapt the adapter parameters
    for p in model.parameters():
        p.requires_grad_(False)
    for p in model.adapter.parameters():
        p.requires_grad_(True)

    opt = torch.optim.Adam(model.adapter.parameters(), lr=lr)

    for _ in range(steps):
        opt.zero_grad()
        lo0, _, _ = model.forward(x)
        p0 = F.softmax(lo0 / temperature, dim=-1)
        ent = _entropy(p0).mean()

        kl_losses = []
        for _k in range(K):
            xr = apply_random_rotation_xyz(x, max_deg=max_deg)
            lor, _, _ = model.forward(xr)
            pr = F.log_softmax(lor / temperature, dim=-1)
            kl = F.kl_div(pr, p0.detach(), reduction='batchmean', log_target=False)
            kl_losses.append(kl)

        kl_avg = sum(kl_losses) / len(kl_losses)
        loss = ent + weight_cons * kl_avg
        loss.backward()
        opt.step()

    # Ensemble prediction
    probs = []
    for _k in range(K):
        xr = apply_random_rotation_xyz(x, max_deg=max_deg)
        lor, _, _ = model.forward(xr)
        probs.append(F.softmax(lor, dim=-1))
    probs = torch.stack(probs, dim=0).mean(dim=0)

    # Restore gradient requirements
    for p in model.parameters():
        p.requires_grad_(True)
    if was_training:
        model.train()
    return probs


# ----------------------------
# 2.7) Proto-Orbit Contrast (Fixed version)
# ----------------------------
def proto_orbit_loss(model, x, y, K=2, max_deg=20.0, temperature=0.2):
    """Fixed proto-orbit loss with proper gradient handling"""
    proto = model.prototypes.detach()
    proto_n = F.normalize(proto, dim=-1)

    # Check if prototypes are initialized
    if torch.all(proto == 0):
        # Return zero loss but maintain gradient flow
        dummy_loss = torch.zeros(1, device=x.device, requires_grad=True)
        return dummy_loss.sum()

    losses = []
    for _ in range(K):
        xr = apply_random_rotation_xyz(x, max_deg=max_deg)
        _, _, z_r = model(xr)
        z_r = F.normalize(z_r, dim=-1)
        logits = (z_r @ proto_n.t()) / temperature
        loss = F.cross_entropy(logits, y)
        losses.append(loss)

    return sum(losses) / len(losses)


# ----------------------------
# 3) Dataset & Dataloader (Fixed version)
# ----------------------------
class UCIHARInertial(Dataset):
    def __init__(self, X, y, scaler: StandardScaler=None, train: bool=True,
                 rot_ssl: bool=True, num_rotations: int=4, p_rotate_ssl: float=0.5):
        self.X = X.astype(np.float32)
        self.y = y.astype(np.int64)
        self.train = train
        self.rot_ssl = rot_ssl
        self.num_rot = num_rotations
        self.p_rotate_ssl = p_rotate_ssl

        # Standardization per-channel
        self.scaler = scaler
        if self.scaler is not None:
            N, T, C = self.X.shape
            X2 = self.X.reshape(-1, C)
            X2 = self.scaler.transform(X2)
            self.X = X2.reshape(N, T, C)

    def __len__(self):
        return len(self.y)

    def __getitem__(self, idx):
        x = torch.from_numpy(self.X[idx])
        y = torch.tensor(self.y[idx], dtype=torch.long)

        # Default values
        rot_label = torch.tensor(0, dtype=torch.long)
        x_ssl = x.clone()
        apply_rot = False

        if self.train and self.rot_ssl and np.random.rand() < self.p_rotate_ssl:
            apply_rot = True
            x_ssl_batch, rot_label_batch = make_rotation_class(x.unsqueeze(0), num_rotations=self.num_rot)
            x_ssl = x_ssl_batch.squeeze(0)
            rot_label = rot_label_batch.squeeze(0) if rot_label_batch.dim() > 0 else rot_label_batch

        return x_ssl, y, rot_label, torch.tensor(apply_rot, dtype=torch.bool)


# ----------------------------
# 4) Training & Eval (Fixed version)
# ----------------------------
def train_one_epoch(model, loader, optimizer, device,
                    lambda_rot=0.5, lambda_rc=0.0,
                    lambda_orbit=0.2, orbit_K=2, orbit_tau=0.2, orbit_max_deg=20.0):
    model.train()
    total_loss = total_act = total_rot = total_orbit = 0.0
    n = 0

    for batch_idx, (xb, yb, rb, msk) in enumerate(loader):
        try:
            xb, yb, rb, msk = xb.to(device), yb.to(device), rb.to(device), msk.to(device)

            # Forward pass
            lo_act, lo_rot, z = model(xb)
            loss_act = F.cross_entropy(lo_act, yb)

            # Rotation SSL loss (fixed gradient handling)
            if msk.any():
                # Only compute loss for rotated samples
                loss_rot = F.cross_entropy(lo_rot[msk], rb[msk])
            else:
                # Create a zero loss that maintains gradient flow
                loss_rot = torch.zeros_like(loss_act)

            # Update prototypes (no gradients)
            with torch.no_grad():
                model.update_prototypes(z, yb)

            # Proto-Orbit loss (fixed version)
            if lambda_orbit > 0.0 and model.proto_counts.sum() > 0:
                loss_orbit = proto_orbit_loss(model, xb, yb, K=orbit_K,
                                            max_deg=orbit_max_deg, temperature=orbit_tau)
            else:
                loss_orbit = torch.zeros_like(loss_act)

            # Total loss
            loss = loss_act + lambda_rot * loss_rot + lambda_orbit * loss_orbit

            # Optional rotation-consistency
            if lambda_rc > 0:
                rc = rotation_consistency_loss(model, xb, K=2, max_deg=15.0)
                loss = loss + lambda_rc * rc

            # Backward pass
            optimizer.zero_grad()
            loss.backward()
            torch.nn.utils.clip_grad_norm_(model.parameters(), 1.0)
            optimizer.step()

            # Accumulate metrics
            bs = xb.size(0)
            n += bs
            total_loss += float(loss.item()) * bs
            total_act += float(loss_act.item()) * bs
            total_rot += float(loss_rot.item()) * bs
            total_orbit += float(loss_orbit.item()) * bs

        except Exception as e:
            print(f"Error in batch {batch_idx}: {e}")
            continue

    return total_loss/n, total_act/n, total_rot/n, total_orbit/n


@torch.no_grad()
def evaluate(model, loader, device, use_tta=False):
    model.eval()
    y_true, y_pred = [], []

    for xb, yb, _, _ in loader:
        xb = xb.to(device)
        y_true.append(yb.numpy())

        if use_tta:
            probs = roco_tta_predict(model, xb, steps=2, K=3, lr=3e-3,
                                   max_deg=20.0, weight_cons=1.0)
            pred = probs.argmax(dim=-1).cpu().numpy()
        else:
            lo, _, _ = model(xb)
            pred = lo.argmax(dim=-1).cpu().numpy()

        y_pred.append(pred)

    y_true = np.concatenate(y_true)
    y_pred = np.concatenate(y_pred)

    acc = accuracy_score(y_true, y_pred)
    f1 = f1_score(y_true, y_pred, average="macro")
    return acc, f1


# ----------------------------
# 5) Main (Fixed version)
# ----------------------------
def main():
    # Use current working directory
    root = "/content/drive/MyDrive/Colab Notebooks"
    data_dir = download_and_extract_uci(root)

    Xtr, ytr, Xte, yte, label_names = load_uci_inertial(data_dir)
    print("[Data] Train:", Xtr.shape, " Test:", Xte.shape)

    # Standardize per-channel using train only
    scaler = StandardScaler()
    scaler.fit(Xtr.reshape(-1, Xtr.shape[2]))

    train_ds = UCIHARInertial(Xtr, ytr, scaler=scaler, train=True,
                              rot_ssl=True, num_rotations=4, p_rotate_ssl=0.7)
    test_ds = UCIHARInertial(Xte, yte, scaler=scaler, train=False,
                             rot_ssl=False, num_rotations=4, p_rotate_ssl=0.0)

    train_loader = DataLoader(train_ds, batch_size=64, shuffle=True, num_workers=0, drop_last=False)
    test_loader = DataLoader(test_ds, batch_size=128, shuffle=False, num_workers=0, drop_last=False)

    device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
    print(f"[Device] Using: {device}")

    model = InceptionMK(input_channels=9, stem_out=64, block_out=32,
                        embedding_dim=128, num_classes=6, num_rotations=4).to(device)

    optimizer = torch.optim.AdamW(model.parameters(), lr=1e-3, weight_decay=1e-3)
    scheduler = torch.optim.lr_scheduler.CosineAnnealingLR(optimizer, T_max=25, eta_min=1e-5)

    best = {"acc": 0.0, "f1": 0.0, "state": None}
    epochs = 25

    for ep in range(1, epochs+1):
        tr_loss, tr_act, tr_rot, tr_orb = train_one_epoch(
            model, train_loader, optimizer, device,
            lambda_rot=0.3,      # Reduced for stability
            lambda_rc=0.0,       # Disabled for now
            lambda_orbit=0.1,    # Reduced for stability
            orbit_K=1,           # Reduced K
            orbit_tau=0.3,
            orbit_max_deg=15.0   # Reduced rotation
        )

        acc, f1 = evaluate(model, test_loader, device, use_tta=False)
        scheduler.step()

        print(f"[Epoch {ep:02d}] loss={tr_loss:.4f} "
              f"(act={tr_act:.4f}, rot={tr_rot:.4f}, orbit={tr_orb:.4f}) | "
              f"test acc={acc*100:.2f}% f1={f1:.4f}")

        if f1 > best["f1"]:
            best.update({"acc": acc, "f1": f1, "state": {k: v.cpu() for k, v in model.state_dict().items()}})
            print(f"  ★ New best F1: {f1:.4f}")

    # Load best model
    if best["state"] is not None:
        model.load_state_dict({k: v.to(device) for k, v in best["state"].items()})

    # Final evaluation
    acc_off, f1_off = evaluate(model, test_loader, device, use_tta=False)
    print(f"\n[Final - No TTA] ACC={acc_off*100:.2f}%  F1={f1_off:.4f}")

    try:
        acc_tta, f1_tta = evaluate(model, test_loader, device, use_tta=True)
        print(f"[Final - With TTA] ACC={acc_tta*100:.2f}%  F1={f1_tta:.4f}")
    except Exception as e:
        print(f"TTA evaluation failed: {e}")

    print("\n=== Training Complete ===")


if __name__ == "__main__":
    main()

[Data] Train: (7352, 128, 9)  Test: (2947, 128, 9)
[Device] Using: cuda
[Epoch 01] loss=0.7486 (act=0.4961, rot=0.5481, orbit=0.8802) | test acc=85.71% f1=0.8481
  ★ New best F1: 0.8481
[Epoch 02] loss=0.3765 (act=0.2559, rot=0.2286, orbit=0.5202) | test acc=90.36% f1=0.9037
  ★ New best F1: 0.9037
[Epoch 03] loss=0.3412 (act=0.2277, rot=0.2248, orbit=0.4607) | test acc=91.01% f1=0.9091
  ★ New best F1: 0.9091
[Epoch 04] loss=0.3240 (act=0.2218, rot=0.1963, orbit=0.4322) | test acc=88.70% f1=0.8859
[Epoch 05] loss=0.2900 (act=0.1986, rot=0.1676, orbit=0.4106) | test acc=92.03% f1=0.9201
  ★ New best F1: 0.9201
[Epoch 06] loss=0.2804 (act=0.1948, rot=0.1520, orbit=0.4000) | test acc=91.21% f1=0.9120
[Epoch 07] loss=0.2631 (act=0.1816, rot=0.1440, orbit=0.3831) | test acc=89.89% f1=0.8976
[Epoch 08] loss=0.2682 (act=0.1855, rot=0.1482, orbit=0.3830) | test acc=88.84% f1=0.8880
[Epoch 09] loss=0.2513 (act=0.1716, rot=0.1426, orbit=0.3688) | test acc=92.09% f1=0.9211
  ★ New best F1: 0.921

In [18]:
# ============================================================
# UCI-HAR (Inertial Signals, raw 9ch) Training & Evaluation
# InceptionMK + Rotation SSL + (optional) Rotation-Consistency + RoCo-TTA
# + Proto-Orbit Contrast (class-prototype-based rotation-invariant loss)
# ============================================================

import os
import io
import math
import zipfile
import urllib.request
from pathlib import Path
from typing import Tuple

import numpy as np
from sklearn.metrics import accuracy_score, f1_score
from sklearn.preprocessing import StandardScaler

import torch
import torch.nn as nn
import torch.nn.functional as F
from torch.utils.data import Dataset, DataLoader


# ----------------------------
# 0) Utilities: download & load UCI-HAR (raw inertial)
# ----------------------------
UCI_URL = "https://archive.ics.uci.edu/static/public/240/human+activity+recognition+using+smartphones.zip"

def download_and_extract_uci(root: str):
    root = Path(root)
    data_dir = root / "UCI_HAR_Dataset"
    if data_dir.exists():
        return data_dir

    root.mkdir(parents=True, exist_ok=True)
    zip_path = root / "uci_har.zip"
    print("[UCI] Downloading dataset...")
    urllib.request.urlretrieve(UCI_URL, zip_path)

    print("[UCI] Extracting...")
    with zipfile.ZipFile(zip_path, 'r') as zf:
        zf.extractall(root)

    # The archive extracts a folder "UCI HAR Dataset"
    # unify name to UCI_HAR_Dataset
    original = root / "UCI HAR Dataset"
    if original.exists():
        original.rename(data_dir)
    zip_path.unlink(missing_ok=True)
    print("[UCI] Ready at:", str(data_dir))
    return data_dir

def _load_signal_file(path: Path) -> np.ndarray:
    # Each file: N lines, each line has 128 floats separated by space
    arr = np.loadtxt(path, dtype=np.float32)
    # shape: (N, 128)
    return arr

def load_uci_inertial(data_dir: Path):
    """
    Returns:
      X_train: (N_train, 128, 9), y_train: (N_train,)
      X_test:  (N_test, 128, 9), y_test:  (N_test,)
      label_names: list of 6 activities
    """
    base = data_dir

    # 9 inertial signals
    train_inertial = base / "train" / "Inertial Signals"
    test_inertial  = base / "test" / "Inertial Signals"

    files = [
        "total_acc_x", "total_acc_y", "total_acc_z",
        "body_acc_x",  "body_acc_y",  "body_acc_z",
        "body_gyro_x", "body_gyro_y", "body_gyro_z"
    ]
    # Build arrays by stacking channel-wise
    Xtr_list = []
    Xte_list = []
    for name in files:
        ftr = train_inertial / f"{name}_train.txt"
        fte = test_inertial  / f"{name}_test.txt"
        Xtr_list.append(_load_signal_file(ftr))  # (Ntr,128)
        Xte_list.append(_load_signal_file(fte))  # (Nte,128)

    # Stack -> (C, N, 128) then transpose
    X_train = np.stack(Xtr_list, axis=2)  # (Ntr,128,9)
    X_test  = np.stack(Xte_list, axis=2)  # (Nte,128,9)

    # Labels: 1..6
    y_train = np.loadtxt(base / "train" / "y_train.txt", dtype=np.int64).ravel() - 1
    y_test  = np.loadtxt(base / "test" / "y_test.txt", dtype=np.int64).ravel() - 1

    # Label names
    label_names = ["WALKING","WALKING_UPSTAIRS","WALKING_DOWNSTAIRS","SITTING","STANDING","LAYING"]

    return X_train, y_train, X_test, y_test, label_names


# ----------------------------
# 1) Rotation utilities (3-axis group-wise)
# ----------------------------
def _batched_rotation_matrix(angles):  # angles: (B, 3) in radians: rx, ry, rz
    B = angles.shape[0]
    rx, ry, rz = angles[:, 0], angles[:, 1], angles[:, 2]

    cx, sx = torch.cos(rx), torch.sin(rx)
    cy, sy = torch.cos(ry), torch.sin(ry)
    cz, sz = torch.cos(rz), torch.sin(rz)

    Rx = torch.zeros(B, 3, 3, device=angles.device)
    Rx[:, 0, 0] = 1.
    Rx[:, 1, 1] = cx; Rx[:, 1, 2] = -sx
    Rx[:, 2, 1] = sx; Rx[:, 2, 2] = cx

    Ry = torch.zeros(B, 3, 3, device=angles.device)
    Ry[:, 0, 0] = cy;  Ry[:, 0, 2] = sy
    Ry[:, 1, 1] = 1.
    Ry[:, 2, 0] = -sy; Ry[:, 2, 2] = cy

    Rz = torch.zeros(B, 3, 3, device=angles.device)
    Rz[:, 0, 0] = cz; Rz[:, 0, 1] = -sz
    Rz[:, 1, 0] = sz; Rz[:, 1, 1] = cz
    Rz[:, 2, 2] = 1.

    R = Rz @ Ry @ Rx
    return R  # (B, 3, 3)

def apply_random_rotation_xyz(x, max_deg=20.0, group_size=3):
    """
    x: (B, T, C)
    C must be multiple of 3. UCI inertial: 9ch -> OK.
    """
    B, T, C = x.shape
    assert C % group_size == 0
    G = C // group_size

    max_rad = max_deg * math.pi / 180.0
    angles = (torch.rand(B, 3, device=x.device) * 2 - 1) * max_rad
    R = _batched_rotation_matrix(angles)  # (B,3,3)

    x_rot = x.clone()
    for g in range(G):
        seg = x[:, :, g*group_size:(g+1)*group_size]      # (B, T, 3)
        seg2 = seg.reshape(B*T, 3)
        Rb = R.repeat_interleave(T, dim=0)                # (B*T,3,3)
        seg_rot = (Rb @ seg2.unsqueeze(-1)).squeeze(-1)   # (B*T,3)
        x_rot[:, :, g*group_size:(g+1)*group_size] = seg_rot.view(B, T, 3)
    return x_rot

# (1) make_rotation_class: 라벨 스칼라로 통일
def make_rotation_class(x, num_rotations=4):
    # x: (B,T,C)
    device = x.device
    B, T, C = x.shape
    angles_deg = torch.tensor([0., 90., 180., 270.], device=device)
    labels = torch.randint(0, len(angles_deg), (B,), device=device)  # (B,)
    rad = angles_deg[labels] * math.pi / 180.0

    angles = torch.stack([torch.zeros_like(rad), torch.zeros_like(rad), rad], dim=1)  # (B,3)
    R = _batched_rotation_matrix(angles)  # (B,3,3)

    x_rot = x.clone()
    G = C // 3
    for g in range(G):
        seg = x[:, :, g*3:(g+1)*3]           # (B,T,3)
        seg2 = seg.reshape(B*T, 3)
        Rb = R.repeat_interleave(T, dim=0)   # (B*T,3,3)
        seg_rot = (Rb @ seg2.unsqueeze(-1)).squeeze(-1)
        x_rot[:, :, g*3:(g+1)*3] = seg_rot.view(B, T, 3)

    # ★ 라벨 모양 통일: 배치가 1이면 스칼라로
    if B == 1:
        labels = labels.squeeze(0)  # shape: ()
    return x_rot, labels


# ----------------------------
# 2) Model (InceptionMK + adapter + TTA + Prototypes)
# ----------------------------
class DSConv1D(nn.Module):
    def __init__(self, in_channels, out_channels, kernel_size):
        super().__init__()
        self.depthwise = nn.Conv1d(in_channels, in_channels, kernel_size=kernel_size,
                                   padding=kernel_size // 2, groups=in_channels, bias=False)
        self.pointwise = nn.Conv1d(in_channels, out_channels, kernel_size=1, bias=False)
        self.bn = nn.BatchNorm1d(out_channels)
        self.act = nn.ReLU(inplace=True)
    def forward(self, x):
        x = self.depthwise(x)
        x = self.pointwise(x)
        x = self.bn(x)
        x = self.act(x)
        return x

class InceptionBlock(nn.Module):
    def __init__(self, in_channels, out_channels):
        super().__init__()
        self.branch1 = nn.Sequential(
            nn.MaxPool1d(kernel_size=3, stride=1, padding=1),
            DSConv1D(in_channels, out_channels, kernel_size=1)
        )
        self.branch2 = DSConv1D(in_channels, out_channels, kernel_size=1)
        self.branch3 = DSConv1D(in_channels, out_channels, kernel_size=3)
        self.branch4 = DSConv1D(in_channels, out_channels, kernel_size=5)
        self.relu = nn.ReLU(inplace=True)
    def forward(self, x):
        out = torch.cat([
            self.branch1(x), self.branch2(x),
            self.branch3(x), self.branch4(x)
        ], dim=1)
        return self.relu(out)

class MultiKernelBlock(nn.Module):
    def __init__(self, in_channels, out_channels):
        super().__init__()
        self.branch1 = DSConv1D(in_channels, out_channels, kernel_size=1)
        self.branch2 = DSConv1D(in_channels, out_channels, kernel_size=3)
        self.branch3 = DSConv1D(in_channels, out_channels, kernel_size=5)
        self.branch4 = DSConv1D(in_channels, out_channels, kernel_size=7)
        self.relu = nn.ReLU(inplace=True)
    def forward(self, x):
        out = torch.cat([
            self.branch1(x), self.branch2(x),
            self.branch3(x), self.branch4(x)
        ], dim=1)
        return self.relu(out)

class FeatureAdapter(nn.Module):
    def __init__(self, dim):
        super().__init__()
        self.gamma = nn.Parameter(torch.ones(1, dim))
        self.beta  = nn.Parameter(torch.zeros(1, dim))
    def forward(self, z):
        return z * self.gamma + self.beta

class InceptionMK(nn.Module):
    def __init__(self, input_channels=9, stem_out=64, block_out=32,
                 embedding_dim=128, num_classes=6, num_rotations=4):
        super().__init__()
        self.stem = nn.Sequential(
            nn.Conv1d(input_channels, stem_out, kernel_size=3, padding=1, bias=False),
            nn.BatchNorm1d(stem_out),
            nn.ReLU(inplace=True)
        )
        self.inception = InceptionBlock(stem_out, block_out)
        self.mk_block  = MultiKernelBlock(block_out * 4, block_out)

        self.gap = nn.AdaptiveAvgPool1d(1)
        self.fc  = nn.Linear(block_out * 4, embedding_dim)
        self.adapter = FeatureAdapter(embedding_dim)

        self.head_act = nn.Sequential(
            nn.LayerNorm(embedding_dim),
            nn.Linear(embedding_dim, 128),
            nn.Dropout(0.3),
            nn.Linear(128, num_classes)
        )
        self.head_rot = nn.Sequential(
            nn.LayerNorm(embedding_dim),
            nn.Linear(embedding_dim, 128),
            nn.Dropout(0.3),
            nn.Linear(128, num_rotations)
        )

        # --- Proto-Orbit: class prototypes (EMA) ---
        self.num_classes = num_classes
        self.embedding_dim = embedding_dim
        self.register_buffer("prototypes", torch.zeros(num_classes, embedding_dim))
        self.proto_m = 0.99  # EMA momentum

    def forward(self, x):  # x: (B,T,C)
        x = x.transpose(1, 2)           # -> (B,C,T)
        x = self.stem(x)
        x = self.inception(x)
        x = self.mk_block(x)
        x = self.gap(x).squeeze(-1)     # (B, 4*block_out)
        z = self.fc(x)                  # (B, D)
        z = self.adapter(z)
        logits_act = self.head_act(z)
        logits_rot = self.head_rot(z)
        return logits_act, logits_rot, z

    @torch.no_grad()
    def predict(self, x):
        self.eval()
        lo, _, _ = self.forward(x)
        return lo

    @torch.no_grad()
    def update_prototypes(self, z, y):
        """
        z: (B, D) embedding (before heads), y: (B,) class labels
        EMA로 클래스 프로토타입 업데이트. 첫 업데이트는 직접 할당.
        """
        if z.numel() == 0:
            return
        z_n = F.normalize(z, dim=-1)
        classes = y.unique()
        for c in classes:
            mask = (y == c)
            if mask.any():
                m = z_n[mask].mean(dim=0)
                p = self.prototypes[c]
                if torch.all(p == 0):
                    self.prototypes[c] = m
                else:
                    self.prototypes[c] = self.proto_m * p + (1.0 - self.proto_m) * m


# ----------------------------
# 2.5) Rotation-consistency (optional)
# ----------------------------
def rotation_consistency_loss(model, x, K=2, max_deg=15.0, temperature=1.0):
    # Keep model in its current (training) mode; just detach base distribution.
    with torch.no_grad():
        base_logits, _, _ = model.forward(x)
        base_logits = base_logits / temperature
        p_base = F.softmax(base_logits, dim=-1)

    kl_sum = 0.0
    for _ in range(K):
        xr = apply_random_rotation_xyz(x, max_deg=max_deg)
        logits_r, _, _ = model.forward(xr)
        p_r = F.log_softmax(logits_r / temperature, dim=-1)
        kl_sum = kl_sum + F.kl_div(p_r, p_base, reduction='batchmean', log_target=False)

    return kl_sum / K


@torch.no_grad()
def _entropy(p):
    return -(p * (p.clamp_min(1e-8).log())).sum(dim=-1)

@torch.no_grad()
def roco_tta_predict(model, x, steps=3, K=4, lr=5e-3, max_deg=20.0, temperature=1.0, weight_cons=1.0):
    was_training = model.training
    model.eval()

    for p in model.parameters():
        p.requires_grad_(False)
    for p in model.adapter.parameters():
        p.requires_grad_(True)

    opt = torch.optim.Adam(model.adapter.parameters(), lr=lr)
    for _ in range(steps):
        opt.zero_grad(set_to_none=True)
        lo0, _, _ = model.forward(x)
        p0 = F.softmax(lo0 / temperature, dim=-1)
        ent = _entropy(p0).mean()

        kl = 0.0
        for _k in range(K):
            xr = apply_random_rotation_xyz(x, max_deg=max_deg)
            lor, _, _ = model.forward(xr)
            pr = F.log_softmax(lor / temperature, dim=-1)
            kl += F.kl_div(pr, p0.detach(), reduction='batchmean', log_target=False)
        kl = kl / K

        loss = ent + weight_cons * kl
        loss.backward()
        opt.step()

    probs = []
    for _k in range(K):
        xr = apply_random_rotation_xyz(x, max_deg=max_deg)
        lor, _, _ = model.forward(xr)
        probs.append(F.softmax(lor, dim=-1))
    probs = torch.stack(probs, dim=0).mean(dim=0)

    for p in model.parameters():
        p.requires_grad_(True)
    if was_training:
        model.train()
    return probs


# ----------------------------
# 2.7) Proto-Orbit Contrast (InfoNCE/CE with class prototypes)
# ----------------------------

def proto_orbit_loss(model, x, y, K=2, max_deg=20.0, temperature=0.2):
    """
    동일 샘플의 여러 회전 오빗 임베딩 z_r를 '클래스 프로토타입'에 수축.
    logits = z_r · proto^T / tau  →  CE(logits, y)
    """
    proto = model.prototypes.detach()  # Detach prototypes (they're updated via EMA)
    proto_n = F.normalize(proto, dim=-1)

    loss_sum = torch.tensor(0.0, device=x.device)
    for _ in range(K):
        xr = apply_random_rotation_xyz(x, max_deg=max_deg)
        _, _, z_r = model(xr)                 # (B, D)
        z_r = F.normalize(z_r, dim=-1)        # (B, D)
        logits = (z_r @ proto_n.t()) / temperature  # (B, C)
        loss_sum = loss_sum + F.cross_entropy(logits, y)

    return loss_sum / K


# ----------------------------
# 3) Dataset & Dataloader
# ----------------------------
class UCIHARInertial(Dataset):
    def __init__(self, X, y, scaler: StandardScaler=None, train: bool=True,
                 rot_ssl: bool=True, num_rotations: int=4, p_rotate_ssl: float=0.5):
        """
        X: (N, 128, 9), y: (N,)
        scaler: StandardScaler fitted on train (per-channel)
        rot_ssl: if True, returns rotation SSL targets
        """
        self.X = X.astype(np.float32)
        self.y = y.astype(np.int64)
        self.train = train
        self.rot_ssl = rot_ssl
        self.num_rot = num_rotations
        self.p_rotate_ssl = p_rotate_ssl

        # standardization per-channel using train scaler
        self.scaler = scaler
        if self.scaler is not None:
            N, T, C = self.X.shape
            X2 = self.X.reshape(-1, C)  # (N*T, C)
            X2 = self.scaler.transform(X2)
            self.X = X2.reshape(N, T, C)

    def __len__(self): return len(self.y)

    def __getitem__(self, idx):
        x = torch.from_numpy(self.X[idx])  # (128,9)
        y = torch.tensor(self.y[idx], dtype=torch.long)

        # Initialize with default values
        rot_label = torch.tensor(0, dtype=torch.long)
        x_ssl = x.clone()
        apply_rot = 0

        if self.train and self.rot_ssl and np.random.rand() < self.p_rotate_ssl:
            apply_rot = 1
            x_ssl_batch, rot_label_batch = make_rotation_class(x.unsqueeze(0), num_rotations=self.num_rot)
            x_ssl = x_ssl_batch.squeeze(0)
            if rot_label_batch.dim() > 0:
                rot_label = rot_label_batch.squeeze(0).long()
            else:
                rot_label = rot_label_batch.long()

        # Return mask as bool tensor
        msk = torch.tensor(apply_rot, dtype=torch.bool)
        return x_ssl, y, rot_label, msk


# ----------------------------
# 4) Training & Eval
# ----------------------------

def train_one_epoch(model, loader, optimizer, device,
                    lambda_rot=0.5, lambda_rc=0.0,
                    lambda_orbit=0.2, orbit_K=2, orbit_tau=0.2, orbit_max_deg=20.0):
    model.train()
    total_loss = total_act = total_rot = total_orbit = 0.0
    n = 0

    for xb, yb, rb, msk in loader:
        xb, yb, rb, msk = xb.to(device), yb.to(device), rb.to(device), msk.to(device)

        lo_act, lo_rot, z = model(xb)
        loss_act = F.cross_entropy(lo_act, yb)

        # --- Rotation SSL (회전 샘플이 없으면 그래프 연결된 0 사용)
        if msk.any():
            loss_rot = F.cross_entropy(lo_rot[msk], rb[msk])
        else:
            loss_rot = 0.0 * lo_act.sum()

        # --- 프로토타입 EMA 업데이트 (그래프 외부)
        with torch.no_grad():
            model.update_prototypes(z, yb)

        # --- Proto-Orbit (끄면 그래프 연결된 0)
        if lambda_orbit > 0.0:
            loss_orbit = proto_orbit_loss(model, xb, yb,
                                          K=orbit_K, max_deg=orbit_max_deg, temperature=orbit_tau)
        else:
            loss_orbit = 0.0 * lo_act.sum()

        loss = loss_act + lambda_rot * loss_rot + lambda_orbit * loss_orbit

        # (선택) Rotation-consistency
        if lambda_rc > 0.0:
            rc = rotation_consistency_loss(model, xb, K=2, max_deg=15.0)
            loss = loss + lambda_rc * rc

        optimizer.zero_grad(set_to_none=True)
        loss.backward()
        torch.nn.utils.clip_grad_norm_(model.parameters(), 1.0)
        optimizer.step()

        bs = xb.size(0)
        n += bs
        total_loss  += float(loss.detach()) * bs
        total_act   += float(loss_act.detach()) * bs
        total_rot   += float(loss_rot.detach()) * bs
        total_orbit += float(loss_orbit.detach()) * bs

    return total_loss/n, total_act/n, total_rot/n, total_orbit/n



@torch.no_grad()
def evaluate(model, loader, device, use_tta=False):
    model.eval()
    y_true, y_pred = [], []
    for xb, yb, _, _ in loader:
        xb = xb.to(device)
        y_true.append(yb.numpy())
        if use_tta:
            probs = roco_tta_predict(model, xb, steps=2, K=3, lr=3e-3, max_deg=20.0, weight_cons=1.0)
            pred = probs.argmax(dim=-1).cpu().numpy()
        else:
            lo, _, _ = model(xb)
            pred = lo.argmax(dim=-1).cpu().numpy()
        y_pred.append(pred)

    y_true = np.concatenate(y_true)
    y_pred = np.concatenate(y_pred)

    acc = accuracy_score(y_true, y_pred)
    f1  = f1_score(y_true, y_pred, average="macro")
    return acc, f1


# ----------------------------
# 5) Main
# ----------------------------
def main():
    # UPDATED PATH: Use current directory instead of Google Drive
    root = "/content/drive/MyDrive/Colab Notebooks/"
    data_dir = download_and_extract_uci(root)

    Xtr, ytr, Xte, yte, label_names = load_uci_inertial(data_dir)
    print("[Data] Train:", Xtr.shape, " Test:", Xte.shape)

    # Standardize per-channel using train only
    scaler = StandardScaler()
    scaler.fit(Xtr.reshape(-1, Xtr.shape[2]))  # (N*T, C)

    train_ds = UCIHARInertial(Xtr, ytr, scaler=scaler, train=True,
                              rot_ssl=True, num_rotations=4, p_rotate_ssl=0.7)
    test_ds  = UCIHARInertial(Xte, yte, scaler=scaler, train=False,
                              rot_ssl=False, num_rotations=4, p_rotate_ssl=0.0)

    train_loader = DataLoader(train_ds, batch_size=128, shuffle=True,  num_workers=0, drop_last=False)
    test_loader  = DataLoader(test_ds,  batch_size=256, shuffle=False, num_workers=0, drop_last=False)

    device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
    model  = InceptionMK(input_channels=9, stem_out=64, block_out=32,
                         embedding_dim=128, num_classes=6, num_rotations=4).to(device)

    optimizer = torch.optim.AdamW(model.parameters(), lr=3e-4, weight_decay=1e-2)

    best = {"acc": 0.0, "f1": 0.0, "state": None}
    epochs = 25
    for ep in range(1, epochs+1):
        tr_loss, tr_act, tr_rot, tr_orb = train_one_epoch(
            model, train_loader, optimizer, device,
            lambda_rot=0.5,      # rotation SSL loss weight
            lambda_rc=0.0,       # rotation-consistency KL (off by default)
            lambda_orbit=0.2,    # Proto-Orbit loss weight
            orbit_K=2,           # # rotated views per sample
            orbit_tau=0.2,       # InfoNCE temperature
            orbit_max_deg=20.0   # rotation range for orbit
        )
        acc, f1 = evaluate(model, test_loader, device, use_tta=False)
        print(f"[Epoch {ep:02d}] loss={tr_loss:.4f} (act={tr_act:.4f}, rot={tr_rot:.4f}, orbit={tr_orb:.4f}) | "
              f"test acc={acc*100:.2f} f1={f1:.4f}")

        if f1 > best["f1"]:
            best.update({"acc": acc, "f1": f1, "state": {k:v.cpu() for k,v in model.state_dict().items()}})

    # Load best (optional)
    if best["state"] is not None:
        model.load_state_dict({k:v.to(device) for k,v in best["state"].items()})

    # Final: TTA OFF vs ON
    acc_off, f1_off = evaluate(model, test_loader, device, use_tta=False)
    acc_tta, f1_tta = evaluate(model, test_loader, device, use_tta=True)

    print("\n=== Final Results (UCI-HAR inertial) ===")
    print(f"Base  : ACC={acc_off*100:.2f}%  Macro-F1={f1_off:.4f}")
    print(f"TTA ON: ACC={acc_tta*100:.2f}%  Macro-F1={f1_tta:.4f}")

if __name__ == "__main__":
    main()


# Additional debugging and improvement suggestions:

def debug_batch_info(loader, device, model=None):
    """Debug function to check batch composition"""
    print("\n=== Debugging Batch Information ===")
    for i, (xb, yb, rb, msk) in enumerate(loader):
        if i >= 3:  # Only check first 3 batches
            break
        xb, yb, rb, msk = xb.to(device), yb.to(device), rb.to(device), msk.to(device)
        print(f"Batch {i}: shape={xb.shape}, rotated_samples={msk.sum().item()}/{len(msk)}")
        if model is not None and msk.any():
            with torch.no_grad():
                lo_act, lo_rot, z = model(xb)
                print(f"  -> logits shapes: act={lo_act.shape}, rot={lo_rot.shape}, z={z.shape}")
                if msk.any():
                    print(f"  -> rotated logits: {lo_rot[msk].shape}, labels: {rb[msk].shape}")
        print()


# Training with better error handling
def safe_train_one_epoch(model, loader, optimizer, device,
                        lambda_rot=0.5, lambda_rc=0.0,
                        lambda_orbit=0.2, orbit_K=2, orbit_tau=0.2, orbit_max_deg=20.0):
    model.train()
    total_loss = total_act = total_rot = total_orbit = 0.0
    n = 0

    for batch_idx, (xb, yb, rb, msk) in enumerate(loader):
        try:
            xb, yb, rb, msk = xb.to(device), yb.to(device), rb.to(device), msk.to(device)

            lo_act, lo_rot, z = model(xb)
            loss_act = F.cross_entropy(lo_act, yb)

            # Safer rotation SSL handling
            if msk.any():
                rotated_indices = msk.nonzero(as_tuple=True)[0]
                if len(rotated_indices) > 0:
                    loss_rot = F.cross_entropy(lo_rot[rotated_indices], rb[rotated_indices])
                else:
                    loss_rot = torch.zeros_like(loss_act)
            else:
                loss_rot = torch.zeros_like(loss_act)

            # EMA update for class prototypes
            with torch.no_grad():
                model.update_prototypes(z, yb)

            # Proto-Orbit loss
            if lambda_orbit > 0.0:
                try:
                    loss_orbit = proto_orbit_loss(model, xb, yb,
                                                K=orbit_K, max_deg=orbit_max_deg, temperature=orbit_tau)
                except Exception as e:
                    print(f"Proto-orbit loss error in batch {batch_idx}: {e}")
                    loss_orbit = torch.zeros_like(loss_act)
            else:
                loss_orbit = torch.zeros_like(loss_act)

            loss = loss_act + lambda_rot * loss_rot + lambda_orbit * loss_orbit

            # Optional rotation-consistency KL
            if lambda_rc > 0:
                try:
                    rc = rotation_consistency_loss(model, xb, K=2, max_deg=15.0)
                    loss = loss + lambda_rc * rc
                except Exception as e:
                    print(f"Rotation consistency loss error in batch {batch_idx}: {e}")

            optimizer.zero_grad()
            loss.backward()

            # Gradient clipping for stability
            torch.nn.utils.clip_grad_norm_(model.parameters(), max_norm=1.0)

            optimizer.step()

            bs = xb.size(0)
            n += bs
            total_loss += float(loss.detach()) * bs
            total_act += float(loss_act.detach()) * bs
            total_rot += float(loss_rot.detach()) * bs
            total_orbit += float(loss_orbit.detach()) * bs

        except Exception as e:
            print(f"Error in batch {batch_idx}: {e}")
            print(f"Batch shapes: xb={xb.shape if 'xb' in locals() else 'N/A'}, "
                  f"yb={yb.shape if 'yb' in locals() else 'N/A'}, "
                  f"rb={rb.shape if 'rb' in locals() else 'N/A'}, "
                  f"msk={msk.shape if 'msk' in locals() else 'N/A'}")
            continue

    return total_loss/n, total_act/n, total_rot/n, total_orbit/n


# Alternative main function with better debugging
def main_debug():
    """Main function with extensive debugging"""
    print("Starting UCI-HAR training with debugging...")

    root = "/content/drive/MyDrive/Colab Notebooks"
    data_dir = download_and_extract_uci(root)

    Xtr, ytr, Xte, yte, label_names = load_uci_inertial(data_dir)
    print(f"[Data] Train: {Xtr.shape}, Test: {Xte.shape}")
    print(f"[Data] Labels: {np.unique(ytr)} -> {label_names}")

    # Data standardization
    scaler = StandardScaler()
    scaler.fit(Xtr.reshape(-1, Xtr.shape[2]))

    train_ds = UCIHARInertial(Xtr, ytr, scaler=scaler, train=True,
                              rot_ssl=True, num_rotations=4, p_rotate_ssl=0.7)
    test_ds = UCIHARInertial(Xte, yte, scaler=scaler, train=False,
                             rot_ssl=False, num_rotations=4, p_rotate_ssl=0.0)

    train_loader = DataLoader(train_ds, batch_size=64, shuffle=True, num_workers=0, drop_last=False)
    test_loader = DataLoader(test_ds, batch_size=128, shuffle=False, num_workers=0, drop_last=False)

    device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
    print(f"[Device] Using: {device}")

    model = InceptionMK(input_channels=9, stem_out=64, block_out=32,
                        embedding_dim=128, num_classes=6, num_rotations=4).to(device)

    # Debug batch info
    debug_batch_info(train_loader, device, model)

    optimizer = torch.optim.AdamW(model.parameters(), lr=1e-3, weight_decay=1e-3)
    scheduler = torch.optim.lr_scheduler.CosineAnnealingLR(optimizer, T_max=20, eta_min=1e-5)

    best = {"acc": 0.0, "f1": 0.0, "state": None}
    epochs = 20  # Reduced for debugging

    for ep in range(1, epochs + 1):
        tr_loss, tr_act, tr_rot, tr_orb = safe_train_one_epoch(
            model, train_loader, optimizer, device,
            lambda_rot=0.3,      # Reduced weight
            lambda_rc=0.0,       # Disabled for stability
            lambda_orbit=0.1,    # Reduced weight
            orbit_K=1,           # Reduced K for stability
            orbit_tau=0.3,
            orbit_max_deg=15.0   # Reduced rotation
        )

        acc, f1 = evaluate(model, test_loader, device, use_tta=False)
        scheduler.step()

        print(f"[Epoch {ep:02d}] loss={tr_loss:.4f} "
              f"(act={tr_act:.4f}, rot={tr_rot:.4f}, orbit={tr_orb:.4f}) | "
              f"test acc={acc*100:.2f}% f1={f1:.4f} | lr={optimizer.param_groups[0]['lr']:.2e}")

        if f1 > best["f1"]:
            best.update({"acc": acc, "f1": f1, "state": {k: v.cpu() for k, v in model.state_dict().items()}})
            print(f"  ★ New best F1: {f1:.4f}")

    # Load best model
    if best["state"] is not None:
        model.load_state_dict({k: v.to(device) for k, v in best["state"].items()})
        print(f"\n[Best Model] Loaded with F1: {best['f1']:.4f}")

    # Final evaluation
    acc_off, f1_off = evaluate(model, test_loader, device, use_tta=False)
    print(f"\n[Final Eval - No TTA] ACC={acc_off*100:.2f}%  F1={f1_off:.4f}")

    # TTA evaluation (optional, might be slow)
    try:
        acc_tta, f1_tta = evaluate(model, test_loader, device, use_tta=True)
        print(f"[Final Eval - With TTA] ACC={acc_tta*100:.2f}%  F1={f1_tta:.4f}")
    except Exception as e:
        print(f"TTA evaluation failed: {e}")

    print("\n=== Training Complete ===")
    return model, best


# Uncomment to run the debug version instead:
# main_debug()

[Data] Train: (7352, 128, 9)  Test: (2947, 128, 9)
[Epoch 01] loss=1.5389 (act=0.8742, rot=0.9402, orbit=0.9729) | test acc=86.05 f1=0.8587
[Epoch 02] loss=0.6819 (act=0.3839, rot=0.3197, orbit=0.6905) | test acc=90.09 f1=0.9006
[Epoch 03] loss=0.4835 (act=0.2698, rot=0.2346, orbit=0.4819) | test acc=91.18 f1=0.9107
[Epoch 04] loss=0.4018 (act=0.2315, rot=0.1837, orbit=0.3924) | test acc=91.38 f1=0.9132
[Epoch 05] loss=0.3639 (act=0.2152, rot=0.1548, orbit=0.3563) | test acc=90.46 f1=0.9040
[Epoch 06] loss=0.3284 (act=0.1982, rot=0.1273, orbit=0.3330) | test acc=91.35 f1=0.9131
[Epoch 07] loss=0.3279 (act=0.1941, rot=0.1385, orbit=0.3228) | test acc=91.14 f1=0.9112
[Epoch 08] loss=0.3059 (act=0.1838, rot=0.1209, orbit=0.3084) | test acc=90.60 f1=0.9052
[Epoch 09] loss=0.2907 (act=0.1778, rot=0.1043, orbit=0.3036) | test acc=91.25 f1=0.9121
[Epoch 10] loss=0.2853 (act=0.1733, rot=0.1048, orbit=0.2983) | test acc=91.69 f1=0.9166
[Epoch 11] loss=0.2710 (act=0.1671, rot=0.0912, orbit=0.291

RuntimeError: element 0 of tensors does not require grad and does not have a grad_fn