In [7]:
# ============ RCCMix-HAR (Step 10, GeoContextHAR) – structure & size ============

import json
from pathlib import Path

import torch
import torch.nn as nn

print("\n[RCCMix-HAR (Step 10, GeoContextHAR) – structure & size]")

# ---------------------------
# 1) Determine NUM_CLASSES
# ---------------------------
BASE = Path("/content")
CFG_DIR = BASE / "configs"
CLASSES_JSON = CFG_DIR / "classes.json"
# Fallback used when CLASSES_JSON is missing; change this if your setup uses a
# different number of classes.
DEFAULT_NUM_CLASSES = 8

# EAFP: just try to read the config; a missing file falls back to the default.
# The warning is built from CLASSES_JSON / NUM_CLASSES so it cannot drift from
# the actual path or value in use.
try:
    with open(CLASSES_JSON, "r", encoding="utf-8") as f:
        classes_cfg = json.load(f)
except FileNotFoundError:
    NUM_CLASSES = DEFAULT_NUM_CLASSES
    print(f"Warning: {CLASSES_JSON} not found. Using default NUM_CLASSES = {NUM_CLASSES}.")
    print("Please update NUM_CLASSES manually if this does not match your setup.")
else:
    NUM_CLASSES = int(classes_cfg["num_classes"])
    print(f"Detected NUM_CLASSES from configs: {NUM_CLASSES}")

# ---------------------------
# 2) Hyperparameters (must match your RCCMix-HAR Step 10 script)
# ---------------------------
IN_CHANNELS  = 6         # acc+gyro
D_MODEL      = 128
N_HEADS      = 4
N_LAYERS     = 2
D_FF         = 4 * D_MODEL   # 512
DROPOUT      = 0.2
SEQ_LEN      = 8

# Echo the configuration; labels are left-aligned in an 11-char field so the
# "=" signs line up.
print("\nConfig for size check:")
for label, value in (
    ("NUM_CLASSES", NUM_CLASSES),
    ("IN_CHANNELS", IN_CHANNELS),
    ("D_MODEL", D_MODEL),
    ("N_HEADS", N_HEADS),
    ("N_LAYERS", N_LAYERS),
    ("D_FF", D_FF),
    ("DROPOUT", DROPOUT),
    ("SEQ_LEN", SEQ_LEN),
):
    print(f"  {label:<11} = {value}")

# ---------------------------
# 3) Model definition (identical to your script)
# ---------------------------
class DepthwiseSeparableConv1d(nn.Module):
    """Depthwise conv -> pointwise (1x1) conv -> BatchNorm -> GELU -> Dropout.

    Args:
        in_ch: input channels; the depthwise conv uses ``groups=in_ch`` so each
            channel is filtered independently.
        out_ch: output channels produced by the 1x1 pointwise conv.
        k: depthwise kernel size.
        dilation: depthwise dilation.
        dropout: dropout probability applied last.

    The depthwise conv is padded by ``(k // 2) * dilation`` on each side, which
    keeps the time length unchanged for odd ``k``; an even ``k`` makes the output
    ``dilation`` samples longer than the input.
    """

    def __init__(self, in_ch, out_ch, k, dilation=1, dropout=0.0):
        super().__init__()
        # Submodules are registered in the original order (dw, pw, bn, act, drop)
        # so state_dict keys, repr and random initialisation stay the same.
        self.dw = nn.Conv1d(
            in_ch,
            in_ch,
            kernel_size=k,
            padding=dilation * (k // 2),
            dilation=dilation,
            groups=in_ch,
            bias=False,
        )
        self.pw = nn.Conv1d(in_ch, out_ch, kernel_size=1, bias=False)
        self.bn = nn.BatchNorm1d(out_ch)
        self.act = nn.GELU()
        self.drop = nn.Dropout(dropout)

    def forward(self, x):
        # Apply the five stages strictly in sequence.
        for stage in (self.dw, self.pw, self.bn, self.act, self.drop):
            x = stage(x)
        return x


class WindowEncoder(nn.Module):
    """
    rTsfNet-style per-window encoder.

    Each window is augmented with two magnitude channels (acc_norm, gyro_norm),
    passed through two depthwise-separable conv branches with different receptive
    fields (k=9 / dilation=1 and k=19 / dilation=2), fused by a 1x1 conv +
    BatchNorm + GELU + dropout, and averaged over time into a ``d_model`` token.
    Separately, four scalar statistics of the window (acc/gyro RMS magnitude and
    acc/gyro per-channel energy) are mapped by a small MLP to a ``d_model``
    conditioning vector ``g``.

    Returns:
        ``(token, g)``, both of shape [N, d_model] for an input of shape
        [N, in_ch, T].
    """

    def __init__(self, in_ch=6, d_model=128, dropout=0.2):
        super().__init__()
        self.in_ch = in_ch
        self.aug_ch = in_ch + 2  # raw channels + acc_norm + gyro_norm

        # Each branch emits d_model // 2 channels and `mix` expects their
        # concatenation to be d_model wide, so d_model is assumed to be even.
        half = d_model // 2
        # Registration order (b1, b2, mix, bn, act, drop, g_proj) matches the
        # training script so state_dict keys and random init are unchanged.
        self.b1 = DepthwiseSeparableConv1d(self.aug_ch, half, k=9, dilation=1, dropout=dropout)
        self.b2 = DepthwiseSeparableConv1d(self.aug_ch, half, k=19, dilation=2, dropout=dropout)
        self.mix = nn.Conv1d(d_model, d_model, kernel_size=1, bias=False)
        self.bn = nn.BatchNorm1d(d_model)
        self.act = nn.GELU()
        self.drop = nn.Dropout(dropout)

        # 4 window statistics -> d_model conditioning vector.
        self.g_proj = nn.Sequential(
            nn.Linear(4, d_model),
            nn.GELU(),
            nn.Linear(d_model, d_model),
        )

    def forward(self, x):
        # x: [N, C, T] with N = B*L windows. Channels 0-2 are the "acc" triad and
        # 3-5 the "gyr" triad, per the naming used below.
        _n, _c, _t = x.shape  # unpacking enforces a 3-D input

        def magnitude(first):
            # Per-sample Euclidean norm of channels first..first+2; the 1e-8 keeps
            # the sqrt argument strictly positive. Returned as a [N, 1, T] channel.
            a = x[:, first, :]
            b = x[:, first + 1, :]
            c = x[:, first + 2, :]
            return torch.sqrt(a**2 + b**2 + c**2 + 1e-8).unsqueeze(1)

        acc_norm = magnitude(0)
        gyr_norm = magnitude(3)
        x_aug = torch.cat([x, acc_norm, gyr_norm], dim=1)  # [N, C + 2, T]

        # Multi-scale features -> channel fusion -> average over time.
        feats = torch.cat([self.b1(x_aug), self.b2(x_aug)], dim=1)  # [N, d_model, T]
        feats = self.drop(self.act(self.bn(self.mix(feats))))
        token = feats.mean(dim=-1)  # [N, d_model]

        # Conditioning statistics.
        # NOTE(review): up to rounding, acc_en**2 == (acc_rms**2 - 1e-8) / 3 (and
        # likewise for gyro), so two of these four features are deterministic
        # functions of the other two.
        stats = torch.stack(
            [
                acc_norm.squeeze(1).pow(2).mean(dim=-1).sqrt(),  # acc_rms
                gyr_norm.squeeze(1).pow(2).mean(dim=-1).sqrt(),  # gyr_rms
                x[:, 0:3, :].pow(2).mean(dim=(1, 2)).sqrt(),     # acc_en
                x[:, 3:6, :].pow(2).mean(dim=(1, 2)).sqrt(),     # gyr_en
            ],
            dim=-1,
        )  # [N, 4]
        return token, self.g_proj(stats)  # [N, d_model], [N, d_model]


class CondLayerNorm(nn.Module):
    """FiLM-style conditional LayerNorm: LN(x) * (1 + gamma(g)) + beta(g).

    ``gamma`` and ``beta`` are linear maps of the conditioning tensor ``g``, which
    must have the same last dimension (``d_model``) as ``x``. The ``1 +`` makes a
    zero ``gamma`` output leave the normalised ``x`` unscaled.
    """

    def __init__(self, d_model):
        super().__init__()
        self.ln = nn.LayerNorm(d_model)
        self.gamma = nn.Linear(d_model, d_model)
        self.beta = nn.Linear(d_model, d_model)

    def forward(self, x, g):
        scale = 1 + self.gamma(g)
        shift = self.beta(g)
        return self.ln(x) * scale + shift


class RCCBlock(nn.Module):
    """Rotation-conditioned Transformer encoder block.

    Pre-norm layout with two residual sub-layers, each preceded by a
    CondLayerNorm driven by the conditioning tensor ``g``:

        x = x + Dropout(MHA(CLN1(x, g)))
        x = x + Dropout(FFN(CLN2(x, g)))

    Both ``x`` and ``g`` are batch-first tensors of shape [B, S, d_model].
    """

    def __init__(self, d_model=128, n_heads=4, d_ff=512, dropout=0.2):
        super().__init__()
        # Attention sub-layer.
        self.condln1 = CondLayerNorm(d_model)
        self.mha = nn.MultiheadAttention(d_model, n_heads, dropout=dropout, batch_first=True)
        self.drop1 = nn.Dropout(dropout)

        # Feed-forward sub-layer.
        self.condln2 = CondLayerNorm(d_model)
        self.ff = nn.Sequential(
            nn.Linear(d_model, d_ff),
            nn.GELU(),
            nn.Dropout(dropout),
            nn.Linear(d_ff, d_model),
        )
        self.drop2 = nn.Dropout(dropout)

    def forward(self, x, g):
        # x, g: [B, L(+1), d]
        normed = self.condln1(x, g)
        attn_out, _ = self.mha(normed, normed, normed, need_weights=False)
        x = x + self.drop1(attn_out)
        x = x + self.drop2(self.ff(self.condln2(x, g)))
        return x


class GeoContextHAR(nn.Module):
    """
    RCCMix-HAR main body:
      WindowEncoder + rotation-conditioned Transformer + CLS head.

    Each of the L windows in a sequence is encoded independently into a token and
    a conditioning vector g. A learnable CLS token is prepended, with g set to the
    mean of the window g's, and learnable positional embeddings are added. The
    sequence then runs through ``n_layers`` RCCBlocks, and the final-LayerNorm'ed
    CLS token feeds a linear classifier.

    Args:
        in_ch: channels per window passed to WindowEncoder.
        d_model: token width.
        n_layers: number of RCCBlocks.
        n_heads: attention heads per block.
        d_ff: hidden width of each block's feed-forward sub-layer.
        dropout: dropout probability used throughout.
        seq_len: maximum number of windows per sequence; sizes the positional
            table (seq_len + 1 rows, the extra one for CLS).
        num_classes: number of output logits.
    """
    def __init__(self, in_ch=6, d_model=128, n_layers=2, n_heads=4, d_ff=512,
                 dropout=0.2, seq_len=8, num_classes=8):
        super().__init__()
        self.seq_len = seq_len
        self.encoder = WindowEncoder(in_ch=in_ch, d_model=d_model, dropout=dropout)
        self.cls_token = nn.Parameter(torch.zeros(1, 1, d_model))
        self.pos = nn.Parameter(torch.zeros(1, seq_len + 1, d_model))
        self.blocks = nn.ModuleList(
            [RCCBlock(d_model, n_heads, d_ff, dropout) for _ in range(n_layers)]
        )
        self.norm = nn.LayerNorm(d_model)
        self.head = nn.Linear(d_model, num_classes)

        nn.init.trunc_normal_(self.pos, std=0.02)
        nn.init.trunc_normal_(self.cls_token, std=0.02)

    def forward(self, x):
        """Classify a batch of window sequences.

        Args:
            x: tensor of shape [B, L, C, T] with 1 <= L <= seq_len.

        Returns:
            Logits of shape [B, num_classes].

        Raises:
            ValueError: if L is 0 or exceeds the ``seq_len`` the model was built
                with (the positional table has no rows for extra windows).
        """
        B, L, C, T = x.shape
        if not 1 <= L <= self.seq_len:
            raise ValueError(
                f"expected between 1 and {self.seq_len} windows per sequence, got L={L}"
            )

        x = x.reshape(B * L, C, T)
        token, g = self.encoder(x)              # [B*L, d_model], [B*L, d_model]
        token = token.view(B, L, -1)           # [B, L, d_model]
        g     = g.view(B, L, -1)               # [B, L, d_model]

        cls = self.cls_token.expand(B, 1, -1)  # [B, 1, d_model]
        z = torch.cat([cls, token], dim=1)     # [B, L+1, d_model]
        g_cls = g.mean(dim=1, keepdim=True)    # [B, 1, d_model]
        g_all = torch.cat([g_cls, g], dim=1)   # [B, L+1, d_model]

        # Use only the first L+1 positional rows (CLS + L windows). For
        # L == seq_len this is the whole table, exactly as before; shorter
        # sequences previously failed to broadcast.
        z = z + self.pos[:, : L + 1]

        for blk in self.blocks:
            z = blk(z, g_all)

        z = self.norm(z)
        cls_rep = z[:, 0, :]                   # [B, d_model]
        logits = self.head(cls_rep)            # [B, num_classes]
        return logits

# ---------------------------
# 4) Instantiate model and compute size
# ---------------------------
model = GeoContextHAR(
    in_ch=IN_CHANNELS, d_model=D_MODEL, n_layers=N_LAYERS, n_heads=N_HEADS,
    d_ff=D_FF, dropout=DROPOUT, seq_len=SEQ_LEN, num_classes=NUM_CLASSES,
)

# Full module tree (header, blank line, then the repr).
print("\n====== nn.Module structure ======\n", model, sep="\n")

# Parameter counts: a single pass accumulating the total and the trainable subset.
total_params = 0
trainable_params = 0
for p in model.parameters():
    total_params += p.numel()
    if p.requires_grad:
        trainable_params += p.numel()

print(
    "\n====== Parameter statistics ======",
    f"Total params:      {total_params:,}",
    f"Trainable params:  {trainable_params:,}",
    sep="\n",
)

# One line per named parameter: name (padded to 45), shape, element count.
print("\n====== Per-layer parameter counts ======")
for name, p in model.named_parameters():
    print("{:45s} shape={}  params={:,}".format(name, tuple(p.shape), p.numel()))

# Size estimation (parameters only)
def fmt_mb(n_bytes: int) -> str:
    """Format a byte count in mebibytes with two decimals, e.g. ``"1.50 MiB"``.

    The divisor is 1024 * 1024 (2**20), so the unit is labelled ``MiB``. The
    previous ``MB`` label implied decimal megabytes (10**6 bytes), which
    understates the figure by about 4.6%.
    """
    return f"{n_bytes / (1024 * 1024):.2f} MiB"

# Parameter-only size estimate: element count x bytes per element.
bytes_fp32 = 4 * total_params
bytes_fp16 = 2 * total_params

print(
    "\n====== Model size estimate (parameters only) ======",
    f"FP32 (float32, 4B/param): {fmt_mb(bytes_fp32)}",
    f"FP16 (float16, 2B/param): {fmt_mb(bytes_fp16)}",
    sep="\n",
)

# Serialise the freshly initialised state_dict to measure the real on-disk size.
# The state_dict also holds buffers (e.g. the BatchNorm running statistics) plus
# serialisation overhead, so the file is not expected to match the
# parameter-only estimate exactly.
models_dir = BASE / "models"
models_dir.mkdir(parents=True, exist_ok=True)
tmp_path = models_dir / "rccmix_har_step10_dummy.pt"
torch.save(model.state_dict(), tmp_path)
file_bytes = tmp_path.stat().st_size

print(
    f"\nRandom-initialised state_dict saved to: {tmp_path.name}",
    f"Actual .pt file size:                 {fmt_mb(file_bytes)}",
    sep="\n",
)

print("\n[RCCMix-HAR (Step 10, GeoContextHAR) – structure & size done]\n")


[RCCMix-HAR (Step 10, GeoContextHAR) – structure & size]
Please update NUM_CLASSES manually if this does not match your setup.

Config for size check:
  NUM_CLASSES = 8
  IN_CHANNELS = 6
  D_MODEL     = 128
  N_HEADS     = 4
  N_LAYERS    = 2
  D_FF        = 512
  DROPOUT     = 0.2
  SEQ_LEN     = 8


GeoContextHAR(
  (encoder): WindowEncoder(
    (b1): DepthwiseSeparableConv1d(
      (dw): Conv1d(8, 8, kernel_size=(9,), stride=(1,), padding=(4,), groups=8, bias=False)
      (pw): Conv1d(8, 64, kernel_size=(1,), stride=(1,), bias=False)
      (bn): BatchNorm1d(64, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)
      (act): GELU(approximate='none')
      (drop): Dropout(p=0.2, inplace=False)
    )
    (b2): DepthwiseSeparableConv1d(
      (dw): Conv1d(8, 8, kernel_size=(19,), stride=(1,), padding=(18,), dilation=(2,), groups=8, bias=False)
      (pw): Conv1d(8, 64, kernel_size=(1,), stride=(1,), bias=False)
      (bn): BatchNorm1d(64, eps=1e-05, momentum=0.1, affine=Tr