In [4]:
# ================================================================
# rTsfNet (Step 10 official architecture aligned · TSF-Mixer variant)
# × Option 1: NVML-based GPU inference energy — mJ per window (Colab one-click)
# ================================================================
# 0) Environment check & dependencies
!nvidia-smi
!pip -q install pynvml

import os, math, json, time, pathlib, gc, warnings, multiprocessing as mp
from pathlib import Path
warnings.filterwarnings("ignore")

import numpy as np
import pandas as pd

# ---------------- Option 1: NVML sampling + numerical integration + idle subtraction (per-window reporting) ----------------
import pynvml

def _nvml_sampler(stop_event, q, dev_index=0, interval=0.02):
    """Subprocess loop: push (time.perf_counter() timestamp, power in mW) tuples onto `q`.

    Samples NVML device `dev_index` roughly every `interval` seconds until
    `stop_event` is set. NVML is initialised and shut down inside this process;
    the shutdown runs even if obtaining the device handle or a reading fails.
    """
    import time, pynvml
    pynvml.nvmlInit()
    try:
        # Acquire the handle inside the try so nvmlShutdown() is always paired with nvmlInit().
        handle = pynvml.nvmlDeviceGetHandleByIndex(dev_index)
        while not stop_event.is_set():
            q.put((time.perf_counter(), pynvml.nvmlDeviceGetPowerUsage(handle)))
            # Event.wait doubles as the inter-sample sleep but returns as soon as stop is
            # requested, so the parent's shutdown is not delayed by up to `interval`.
            stop_event.wait(interval)
    finally:
        pynvml.nvmlShutdown()

def _integrate_mJ_between(samples, t0, t1):
    """Trapezoidal integration of power (mW) over [t0, t1], returning mJ."""
    if not samples:
        return 0.0
    samples = sorted(samples, key=lambda x: x[0])
    ts = np.array([t for t, _ in samples], dtype=np.float64)
    ps = np.array([p for _, p in samples], dtype=np.float64)
    m = (ts >= t0) & (ts <= t1)
    ts_w, ps_w = ts[m], ps[m]
    if ts_w.size == 0 or ts_w[0] > t0:
        p0 = np.interp(t0, ts, ps); ts_w = np.insert(ts_w, 0, t0); ps_w = np.insert(ps_w, 0, p0)
    if ts_w[-1] < t1:
        p1 = np.interp(t1, ts, ps); ts_w = np.append(ts_w, t1); ps_w = np.append(ps_w, p1)
    return float(np.trapz(ps_w, ts_w))  # mW*s = mJ

def sample_idle_power_mW(duration_s=20.0, dev_index=0, interval=0.02, save_csv=None):
    """Measure mean idle GPU power (mW) by sampling NVML for `duration_s` seconds.

    The mean is the trapezoidal energy over the sampled span divided by that
    span, so uneven sample spacing is weighted correctly. If all samples share
    one timestamp (e.g. only one reading), the plain mean of the readings is used.

    Args:
        duration_s: how long to sample, in seconds.
        dev_index: NVML device index.
        interval: sampler period, in seconds.
        save_csv: optional path for the raw (t_abs_s, power_mW) trace.

    Returns:
        (P_idle_mW, samples) with `samples` the time-sorted (t_abs_s, power_mW) list.

    Raises:
        RuntimeError: if NVML produced no samples.
    """
    import time, queue
    q = mp.Queue(); stop = mp.Event()
    p = mp.Process(target=_nvml_sampler, args=(stop, q, dev_index, interval))
    p.start()
    samples = []

    def _drain(timeout):
        # Move everything currently readable from the queue into `samples`.
        while True:
            try:
                samples.append(q.get(timeout=timeout) if timeout else q.get_nowait())
            except queue.Empty:
                return

    try:
        time.sleep(duration_s)
    finally:
        # Always stop the sampler (even on KeyboardInterrupt). Drain BEFORE join:
        # a child with items still buffered for the queue cannot exit, so joining
        # first can deadlock once the trace outgrows the pipe buffer.
        stop.set()
        while p.is_alive():
            _drain(0.1)
        p.join()
        _drain(None)

    if not samples:
        raise RuntimeError("NVML captured no idle power samples.")
    samples.sort(key=lambda x: x[0])
    t0, t1 = samples[0][0], samples[-1][0]
    if t1 > t0:
        P_idle_mW = _integrate_mJ_between(samples, t0, t1) / (t1 - t0)
    else:
        # Zero-width span: the integral is 0, so dividing would report 0 mW. Use the readings instead.
        P_idle_mW = float(np.mean([pw for _, pw in samples]))
    if save_csv:
        pd.DataFrame(samples, columns=["t_abs_s", "power_mW"]).to_csv(save_csv, index=False)
    return P_idle_mW, samples

def measure_mJ_per_window(run_once, n_windows_per_call, repeats, P_idle_mW,
                          dev_index=0, interval=0.02, save_csv=None):
    """Measure idle-subtracted GPU energy and latency per inference window.

    A child process samples NVML power while `run_once` is called `repeats`
    times. Power is integrated over the wall-clock span [t0, t1] of those calls,
    the idle baseline (P_idle_mW * duration) is subtracted, and the result is
    divided by the number of windows processed.

    NOTE: t0/t1 come from this process and the sample timestamps from the
    sampler process; both use time.perf_counter, which on Linux is the
    system-wide CLOCK_MONOTONIC. Python does not guarantee cross-process
    comparability on every platform.

    Args:
        run_once: zero-argument callable performing one inference pass.
        n_windows_per_call: windows processed per `run_once()` call.
        repeats: number of `run_once()` calls in the measured span.
        P_idle_mW: idle power baseline in mW.
        dev_index: NVML device index.
        interval: sampler period in seconds.
        save_csv: optional path for the raw (t_abs_s, power_mW) trace.

    Returns:
        dict with per-window energy (mJ, clamped at 0), per-window latency (ms),
        throughput, and the raw totals used to compute them.

    Raises:
        RuntimeError: if the sampler produced no power readings.
    """
    import time, queue
    q = mp.Queue(); stop = mp.Event()
    p = mp.Process(target=_nvml_sampler, args=(stop, q, dev_index, interval))
    p.start()
    samples = []

    def _drain(timeout):
        # Move everything currently readable from the queue into `samples`.
        while True:
            try:
                samples.append(q.get(timeout=timeout) if timeout else q.get_nowait())
            except queue.Empty:
                return

    try:
        # Wait for the first reading so the sampler is live before t0. Otherwise the
        # start of [t0, t1] is filled by clamping the first (late) sample backwards.
        deadline = time.perf_counter() + 30.0
        while not samples and p.is_alive() and time.perf_counter() < deadline:
            try:
                samples.append(q.get(timeout=0.5))
            except queue.Empty:
                pass
        if not samples:
            raise RuntimeError("NVML captured no power samples (active phase).")

        t0 = time.perf_counter()
        for _ in range(repeats):
            run_once()
        t1 = time.perf_counter()
        # Let a couple of readings land after t1 so the window end is bracketed too.
        time.sleep(2 * interval)
    finally:
        # Always stop the sampler (even if run_once raises). Drain BEFORE join to
        # avoid the queue-feeder deadlock.
        stop.set()
        while p.is_alive():
            _drain(0.1)
        p.join()
        _drain(None)

    E_total_mJ = _integrate_mJ_between(samples, t0, t1)
    T_total_s  = max(1e-9, t1 - t0)
    E_idle_mJ  = P_idle_mW * T_total_s
    n_windows  = max(1, repeats * n_windows_per_call)

    if save_csv:
        pd.DataFrame(samples, columns=["t_abs_s", "power_mW"]).to_csv(save_csv, index=False)

    return {
        "mJ_per_window": max(0.0, (E_total_mJ - E_idle_mJ) / n_windows),
        "ms_per_window": (T_total_s / n_windows) * 1e3,
        "throughput_windows_per_s": n_windows / T_total_s,
        "n_windows": n_windows,
        "repeats": repeats,
        "T_total_s": T_total_s,
        "E_total_mJ": E_total_mJ,
        "E_idle_mJ": E_idle_mJ,
        "P_idle_mW": P_idle_mW,
        "t0_abs": t0, "t1_abs": t1
    }

def calibrate_repeats(run_once, target_s=8.0, min_rep=3, max_rep=5000):
    """Choose how many `run_once()` calls make one measured pass last about `target_s` seconds.

    One untimed call absorbs first-call overhead. The following call is timed
    (floored at 0.1 ms), and ceil(target_s / elapsed) is clamped to
    [min_rep, max_rep] and returned as a Python int.
    """
    import time
    run_once()  # warm-up, not timed
    start = time.perf_counter()
    run_once()
    elapsed = max(1e-4, time.perf_counter() - start)
    wanted = int(np.ceil(target_s / elapsed))
    return int(np.clip(wanted, min_rep, max_rep))

def measure_with_bootstrap(name, run_once, n_windows, repeats, n_runs=5, n_boot=1000, logdir=Path("logs"),
                           p_idle_mW=None):
    """Repeat the per-window energy measurement `n_runs` times and summarise it.

    Each run's raw power trace goes to logdir/power_trace_{name}_run{i}.csv.
    The summary (mean mJ/window, percentile-bootstrap 95% CI of that mean over
    runs, mean ms/window, and the per-run result dicts) is written to
    logdir/energy_{name}.json and returned.

    Args:
        name: tag used in file names and log lines.
        run_once, n_windows, repeats: forwarded to measure_mJ_per_window.
        n_runs: number of independent measurement runs (>= 1).
        n_boot: bootstrap resamples for the CI (>= 1).
        logdir: output directory (str or Path); created if missing.
        p_idle_mW: idle baseline in mW. When None, the module-level global
            `P_idle_mW` is used, which is the original behaviour.

    Raises:
        ValueError: if n_runs < 1 or n_boot < 1.
        NameError: if p_idle_mW is None and no global P_idle_mW exists.
    """
    if n_runs < 1:
        raise ValueError(f"n_runs must be >= 1, got {n_runs}")
    if n_boot < 1:
        raise ValueError(f"n_boot must be >= 1, got {n_boot}")
    if p_idle_mW is None:
        # Backward compatibility: fall back to the notebook-level idle baseline.
        if "P_idle_mW" not in globals():
            raise NameError("P_idle_mW is not defined; measure idle power first or pass p_idle_mW=...")
        p_idle_mW = globals()["P_idle_mW"]

    logdir = Path(logdir)
    logdir.mkdir(parents=True, exist_ok=True)
    res_list = []
    for i in range(n_runs):
        print(f"[Measure] {name} run {i+1}/{n_runs} ...")
        r = measure_mJ_per_window(
            run_once, n_windows, repeats, p_idle_mW,
            dev_index=0, interval=0.02,
            save_csv=str(logdir / f"power_trace_{name}_run{i+1}.csv")
        )
        res_list.append(r)

    mJ = np.array([r["mJ_per_window"] for r in res_list], dtype=np.float64)
    ms = np.array([r["ms_per_window"] for r in res_list], dtype=np.float64)
    # Percentile bootstrap of the mean over runs. The fixed seed keeps the CI reproducible;
    # with only a handful of runs the interval is necessarily coarse.
    rng = np.random.default_rng(123)
    boots = [float(np.mean(mJ[rng.integers(0, len(mJ), size=len(mJ))])) for _ in range(n_boot)]
    ci_low, ci_high = np.percentile(boots, [2.5, 97.5])

    summary = {
        "model": name,
        "mean_mJ_per_window": float(mJ.mean()),
        "ci95_low_mJ": float(ci_low),
        "ci95_high_mJ": float(ci_high),
        "mean_ms_per_window": float(ms.mean()),
        "runs": res_list
    }
    with open(logdir / f"energy_{name}.json", "w") as f:
        json.dump(summary, f, indent=2)
    print(f"[Result] {name}: {summary['mean_mJ_per_window']:.3f} mJ per window "
          f"(95% CI [{summary['ci95_low_mJ']:.3f}, {summary['ci95_high_mJ']:.3f}]); "
          f"{summary['mean_ms_per_window']:.3f} ms per window")
    return summary

# ---------------- rTsfNet (Step 10: official architecture aligned · TSF-Mixer etc.) [architecture/hyperparameters unchanged] ----------------
import tensorflow as tf
from tensorflow.keras import Input
from tensorflow.keras.layers import (
    Dense, Dropout, LayerNormalization, LeakyReLU,
    Layer, Activation, TimeDistributed, Flatten, Concatenate, GlobalAveragePooling1D
)
from tensorflow.keras.models import Model
from tensorflow.keras.optimizers import Adam
from tensorflow.keras.regularizers import l2
from tensorflow.keras import backend as K
# Reproducibility: seed TensorFlow's global RNG and NumPy's legacy global RNG.
SEED = 42
tf.random.set_seed(SEED); np.random.seed(SEED)
# Sampling rate in samples per second: window_seconds = T / FS below, and the
# TSF frequency axis spans [0, FS/2].
FS = 50.0
# Number of rotation heads in Multihead3DRotationOfficial (one rotated stream each).
IMU_ROT_HEADS = 2
# MLP width base: MLPStack layers have widths MLP_BASE * 2**k, narrowing down to MLP_BASE.
MLP_BASE = 128
# Depth of the classifier MLPStack; inner stacks use reduced depths (max(1, MLP_DEPTH - 1) or less).
MLP_DEPTH = 3
# Dropout rate used in every MLP stack.
DROPOUT = 0.5
# Adam learning rate (only used by model.compile; the code shown does not train).
LR = 1e-3
# L2 regularization coefficient applied to Dense kernels.
WEIGHT_DECAY = 1e-6
# Also feed the un-rotated input stream into the main feature extractor.
USE_ORIG_INPUT = True
# Route mixer gate probabilities through BinaryGate.
USE_BINARY_SELECTION = True
# LayerNormalization epsilon.
LN_EPS = 1e-7
# tf.pad mode used when padding a window so it splits into equal blocks.
PAD_MODE = 'SYMMETRIC'

class MLPStack(Layer):
    """Stack of `depth` [Dense -> LayerNormalization -> LeakyReLU -> Dropout] groups.

    Dense widths halve from base_kn * 2**(depth-1) down to base_kn, so the
    output width is `base_kn` (see out_dim). Dense kernels carry an l2(wd)
    regularizer.

    Sublayers are created in __init__ and tracked through `self.seq`. Their
    order and the attribute name determine the weight layout, so keep both
    stable to remain compatible with saved weights.
    NOTE(review): with depth <= 0 the stack is empty (identity), and out_dim
    would not match the actual output width.
    """
    def __init__(self, base_kn=128, depth=3, drop=0.5, wd=0.0, ln_eps=1e-7, name=None):
        super().__init__(name=name)
        self.base_kn = int(base_kn); self.depth = int(depth)
        self.drop = float(drop); self.wd = float(wd); self.ln_eps = float(ln_eps)
        self.seq = []
        # k runs depth-1 .. 0, so widths go from widest to base_kn.
        for k in range(self.depth - 1, -1, -1):
            self.seq.append(Dense(self.base_kn * (2**k), kernel_regularizer=l2(self.wd)))
            self.seq.append(LayerNormalization(epsilon=self.ln_eps))
            self.seq.append(LeakyReLU())
            self.seq.append(Dropout(self.drop))
    # Width of the last Dense layer (k == 0).
    @property
    def out_dim(self): return self.base_kn
    def call(self, x, training=None):
        """Apply the stack; only Dropout receives the `training` flag."""
        z = x
        for lyr in self.seq:
            z = lyr(z, training=training) if isinstance(lyr, Dropout) else lyr(z)
        return z
    def compute_output_shape(self, input_shape):
        return tf.TensorShape([input_shape[0], self.out_dim])

# Number of per-axis statistics TSFFeatureLayer emits in the time domain:
# mean, std, max, min, ptp, rms, energy, skew, kurt, zcr, ar1, ar2.
TIME_FEATS = 12
# ...and in the frequency domain: centroid, entropy, flatness, soft_peak and 3 band-power ratios.
FREQ_FEATS = 7

class TSFFeatureLayer(Layer):
    """Weight-free layer computing per-channel statistics over the time axis.

    Input: a rank-3 tensor [B, L, C] (batch, time, channel). Output:
    [B, C, Fnum] with Fnum = TIME_FEATS (if use_time) + FREQ_FEATS (if use_freq).
    Features are ordered time features first, then frequency features.
    """
    def __init__(self, fs=50.0, use_time=True, use_freq=True, **kwargs):
        super().__init__(**kwargs)
        # eps guards divisions and logs against zero.
        self.fs = float(fs); self.use_time = bool(use_time); self.use_freq = bool(use_freq); self.eps = 1e-8
        self._feat_dim = (TIME_FEATS if self.use_time else 0) + (FREQ_FEATS if self.use_freq else 0)
    def get_config(self):
        c = super().get_config(); c.update({'fs': self.fs, 'use_time': self.use_time, 'use_freq': self.use_freq}); return c
    def call(self, x):  # [B,L,C]
        feats = []
        if self.use_time:
            # Moments and range statistics over time (axis 1), each kept as [B,1,C].
            mean = tf.reduce_mean(x, axis=1, keepdims=True)
            std  = tf.math.reduce_std(x, axis=1, keepdims=True) + self.eps
            maxv = tf.reduce_max(x, axis=1, keepdims=True); minv = tf.reduce_min(x, axis=1, keepdims=True)
            ptp  = maxv - minv; rms = tf.sqrt(tf.reduce_mean(tf.square(x), axis=1, keepdims=True))
            energy = tf.reduce_sum(tf.square(x), axis=1, keepdims=True)
            skew = tf.reduce_mean(tf.pow((x-mean)/std, 3), axis=1, keepdims=True)
            kurt = tf.reduce_mean(tf.pow((x-mean)/std, 4), axis=1, keepdims=True)
            # Zero-crossing rate: |sign(x[t+1]) - sign(x[t])| is 2 on a full crossing, hence the /2.
            signs = tf.sign(x); sign_changes = tf.abs(signs[:,1:,:] - signs[:,:-1,:]); zcr = tf.reduce_mean(sign_changes, axis=1, keepdims=True)/2.0
            # ar1 / ar2: sum(x[t]*x[t+k]) / sum(x[t]^2) for lag k = 1, 2 (no mean removal).
            x_t1 = x[:,:-1,:]; x_tn1 = x[:,1:,:]
            ar1 = tf.reduce_sum(x_t1*x_tn1, axis=1, keepdims=True) / (tf.reduce_sum(tf.square(x_t1), axis=1, keepdims=True) + self.eps)
            x_t2 = x[:,:-2,:]; x_tn2 = x[:,2:,:]
            ar2 = tf.reduce_sum(x_t2*x_tn2, axis=1, keepdims=True) / (tf.reduce_sum(tf.square(x_t2), axis=1, keepdims=True) + self.eps)
            feats += [mean,std,maxv,minv,ptp,rms,energy,skew,kurt,zcr,ar1,ar2]
        if self.use_freq:
            # Power spectrum of the mean-removed signal: rfft runs over the last axis, hence the transposes.
            mean = tf.reduce_mean(x, axis=1, keepdims=True); xc = x - mean
            x_bc_t = tf.transpose(xc, [0,2,1]); fft = tf.signal.rfft(x_bc_t); power = tf.square(tf.abs(fft)) + self.eps
            power = tf.transpose(power, [0,2,1]); F = tf.shape(power)[1]
            # Frequency axis evenly spaced over [0, fs/2] with F points.
            # NOTE(review): rfft bin k sits at k*fs/L; this matches linspace(0, fs/2, F) only
            # for even block length L (for odd L the top bin is below fs/2).
            freqs = tf.linspace(0.0, tf.cast(self.fs, tf.float32)/2.0, F); freqs = tf.reshape(freqs, [1,F,1])
            # p: power normalized to sum to ~1 per channel (a spectral distribution).
            p = power / (tf.reduce_sum(power, axis=1, keepdims=True) + self.eps)
            centroid = tf.reduce_sum(p * freqs, axis=1, keepdims=True)
            # Spectral entropy divided by log(F), i.e. normalized by the entropy of a flat spectrum.
            entropy  = -tf.reduce_sum(p * tf.math.log(p + self.eps), axis=1, keepdims=True) / (tf.math.log(tf.cast(F, tf.float32) + self.eps))
            # Spectral flatness: geometric mean / arithmetic mean of the power.
            geo = tf.exp(tf.reduce_mean(tf.math.log(power), axis=1, keepdims=True)); ari = tf.reduce_mean(power, axis=1, keepdims=True)
            flatness = geo / (ari + self.eps)
            # Differentiable peak frequency. The softmax input is the raw (unnormalized) power * 10,
            # so its sharpness depends on the signal's scale.
            w = tf.nn.softmax(power * 10.0, axis=1); soft_peak = tf.reduce_sum(w * freqs, axis=1, keepdims=True)
            def band(low, high):
                # Fraction of total power in [low, high), in the same units as fs.
                mask = tf.cast((freqs >= low) & (freqs < high), tf.float32)
                return tf.reduce_sum(power * mask, axis=1, keepdims=True) / (tf.reduce_sum(power, axis=1, keepdims=True) + self.eps)
            bp1 = band(0.5,3.0); bp2 = band(3.0,8.0); bp3 = band(8.0,15.0)
            feats += [centroid,entropy,flatness,soft_peak,bp1,bp2,bp3]
        res = tf.concat(feats, axis=1)                 # [B,Fnum,C]
        return tf.transpose(res, [0,2,1])              # [B,C,Fnum]
    def compute_output_shape(self, input_shape):
        return tf.TensorShape([input_shape[0], input_shape[2], (TIME_FEATS if self.use_time else 0) + (FREQ_FEATS if self.use_freq else 0)])

class AddL2Channels(Layer):
    """Append the Euclidean norms of channels 0-2 (acc) and 3-5 (gyr) as two new last channels.

    The declared output shape assumes a 6-channel input (6 + 2 = 8 channels).
    """
    def call(self, x, training=None):
        def _norm(triad):
            return tf.sqrt(tf.reduce_sum(tf.square(triad), axis=-1, keepdims=True))
        acc_norm = _norm(x[:, :, :3])
        gyr_norm = _norm(x[:, :, 3:6])
        return tf.concat([x, acc_norm, gyr_norm], axis=-1)

    def compute_output_shape(self, input_shape):
        return tf.TensorShape([input_shape[0], input_shape[1], 8])

def _int_ceil_div(a, b):
    """Return ceil(a / b) as an int32 tensor, computed as floor((a + b - 1) / b); intended for b > 0."""
    num = tf.cast(a, tf.int32)
    den = tf.cast(b, tf.int32)
    rounded_up = num + den - 1
    return tf.math.floordiv(rounded_up, den)

def frame_signal_with_padding(x, num_blocks, pad_mode='SYMMETRIC'):
    """Split a [B, T, C] signal into `num_blocks` equal frames: returns [B, num_blocks, L, C].

    L = ceil(T / num_blocks). The time axis is padded to L * num_blocks with
    tf.pad(mode=pad_mode); the padding is split as evenly as possible, with any
    odd sample going to the right.
    """
    dims = tf.shape(x)
    batch, steps, chans = dims[0], dims[1], dims[2]
    n_blk = tf.cast(num_blocks, tf.int32)
    blk_len = _int_ceil_div(steps, n_blk)
    extra = blk_len * n_blk - steps
    left = tf.math.floordiv(extra, 2)
    right = extra - left
    no_pad = tf.constant([0, 0], tf.int32)
    pad_spec = tf.stack([no_pad, tf.stack([left, right]), no_pad], axis=0)
    padded = tf.pad(x, pad_spec, mode=pad_mode)
    return tf.reshape(padded, [batch, n_blk, blk_len, chans])

class BlockTSFExtractor(Layer):
    """Frame a [B, T, C] signal into `num_blocks` blocks and compute TSF features per block and channel.

    Output: [B, num_blocks, C, base_feat_dim + tag_dim]. When
    tag_spec['axis_tags'] is given (one row per input channel), its row for each
    channel is appended to every block's feature vector for that channel.
    NOTE(review): get_config does not serialize tag_spec.
    """
    def __init__(self, num_blocks, fs, use_time, use_freq, tag_spec=None, pad_mode='SYMMETRIC', name=None, **kwargs):
        super().__init__(name=name, **kwargs)
        # TSFFeatureLayer creates no variables; it is applied to every block independently.
        self.num_blocks = int(num_blocks); self.tsf = TSFFeatureLayer(fs=fs, use_time=use_time, use_freq=use_freq)
        self.tag_spec = tag_spec; self.pad_mode = pad_mode
        self.tag_dim = 0 if (tag_spec is None or 'axis_tags' not in tag_spec) else int(tag_spec['axis_tags'].shape[1])
        self.base_feat_dim = (TIME_FEATS if use_time else 0) + (FREQ_FEATS if use_freq else 0)
        self.out_feat_dim = self.base_feat_dim + self.tag_dim
    def get_config(self):
        c = super().get_config(); c.update({'num_blocks':self.num_blocks, 'fs':self.tsf.fs,
                                            'use_time':self.tsf.use_time,'use_freq':self.tsf.use_freq,
                                            'pad_mode':self.pad_mode}); return c
    def call(self, x, training=None):
        xb = frame_signal_with_padding(x, self.num_blocks, pad_mode=self.pad_mode)  # [B,K,L,C]
        B = tf.shape(xb)[0]; K_ = tf.shape(xb)[1]; L = tf.shape(xb)[2]; C = tf.shape(xb)[3]
        # Fold blocks into the batch so TSFFeatureLayer sees ordinary [N, L, C] windows.
        xb2 = tf.reshape(xb, [B*K_, L, C])
        tsf_axis = self.tsf(xb2)
        tsf_axis = tf.reshape(tsf_axis, [B, K_, C, self.base_feat_dim])
        if self.tag_dim > 0:
            # Broadcast the per-channel tag rows to every (sample, block); requires #tag rows == C.
            axis_tags = tf.convert_to_tensor(self.tag_spec['axis_tags'], dtype=tsf_axis.dtype)
            axis_tags = tf.reshape(axis_tags, [1,1,tf.shape(tsf_axis)[2],-1])
            axis_tags = tf.tile(axis_tags, [B, K_, 1, 1])
            tsf_axis = tf.concat([tsf_axis, axis_tags], axis=-1)
        return tsf_axis
    def compute_output_shape(self, input_shape):
        return tf.TensorShape([input_shape[0], self.num_blocks, input_shape[2], self.out_feat_dim])

class BinaryGate(Layer):
    """Straight-through binary gate.

    Forward: clip the gate probabilities to [0, 1] and round them to {0, 1}.
    Backward: gradients flow as if the op were the clipped identity, because the
    rounding residual is wrapped in stop_gradient.

    The previous formula `hard + stop_gradient(p - hard)` evaluated to the soft
    `p` and had zero gradient, since tf.round has no gradient. Weights trained
    with that formula saw soft gates, so re-validate accuracy with them.
    """
    def call(self, p, training=None):
        p = tf.clip_by_value(p, 0.0, 1.0)
        hard = tf.round(p)
        # Straight-through estimator: value == hard exactly, d(out)/dp == 1.
        return p + tf.stop_gradient(hard - p)
    def compute_output_shape(self, input_shape):
        return tf.TensorShape(input_shape)

class TSFMixerSubBlock(Layer):
    """Per-axis MLP followed by an MLP over the flattened axes.

    Input [B', A, F] (A axes, F features per axis). Each axis vector goes
    through `axis_mlp_layers` (widths axis_hidden * 2**k down to axis_hidden);
    the results are flattened to [B', A * axis_hidden] and passed through
    `out_stack`, giving [B', out_hidden].

    NOTE(review): in the code visible here, TSFMixerBlock only reuses this
    block's `axis_mlp_layers` and `axis_hidden`; this block's own call() and
    `out_stack` are not invoked there. Sublayer creation order defines the
    weight layout, so keep __init__ stable.
    """
    def __init__(self, axis_hidden=128, out_hidden=128, base_depth=2, drop=0.5, wd=0.0, ln_eps=1e-7, name=None):
        super().__init__(name=name)
        self.axis_hidden=int(axis_hidden); self.out_hidden=int(out_hidden); self.base_depth=int(base_depth)
        self.drop=float(drop); self.wd=float(wd); self.ln_eps=float(ln_eps)
        self.axis_mlp_layers=[]
        # Same [Dense -> LN -> LeakyReLU -> Dropout] pattern as MLPStack, narrowing to axis_hidden.
        for k in range(self.base_depth-1, -1, -1):
            self.axis_mlp_layers.append(Dense(self.axis_hidden*(2**k), kernel_regularizer=l2(self.wd)))
            self.axis_mlp_layers.append(LayerNormalization(epsilon=self.ln_eps))
            self.axis_mlp_layers.append(LeakyReLU())
            self.axis_mlp_layers.append(Dropout(self.drop))
        self.out_stack = MLPStack(base_kn=self.out_hidden, depth=self.base_depth, drop=self.drop, wd=self.wd, ln_eps=self.ln_eps, name=f'{self.name}_out')
    def call(self, x, training=None, **kwargs):
        Bp = tf.shape(x)[0]; A = tf.shape(x)[1]; F = tf.shape(x)[2]
        # Treat every axis as an independent row for the shared per-axis MLP.
        x2 = tf.reshape(x, [Bp*A, F]); z = x2
        for lyr in self.axis_mlp_layers:
            z = lyr(z, training=training) if isinstance(lyr, Dropout) else lyr(z)
        z = tf.reshape(z, [Bp, A, self.axis_hidden]); z = tf.reshape(z, [Bp, A*self.axis_hidden])
        z = self.out_stack(z, training=training); return z
    def compute_output_shape(self, input_shape):
        return tf.TensorShape([input_shape[0], self.out_stack.out_dim])

class TSFMixerBlock(Layer):
    """Gated TSF mixer: feature-channel gate -> per-axis MLP -> axis gate -> output MLP.

    Input [B', A, F] -> output [B', out_hidden].
    1. Channel gate: a sigmoid Dense on the axis-averaged features gives one
       weight per feature (requires feat_dim == F), applied to every axis.
    2. The per-axis MLP layers of `self.sub` map each axis to axis_hidden.
    3. Axis gate: a sigmoid Dense gives one weight per axis.
    4. Axes are flattened and passed through `self.out_stack`.
    When use_binary is set, both gate probabilities pass through `self.bin_gate`.
    Sublayer creation order defines the weight layout, so keep __init__ stable.
    """
    def __init__(self, feat_dim, axis_hidden=128, out_hidden=128, base_depth=2, drop=0.5, wd=0.0, ln_eps=1e-7, use_binary=True, name=None):
        super().__init__(name=name)
        self.use_binary=bool(use_binary)
        self.sub = TSFMixerSubBlock(axis_hidden, out_hidden, base_depth, drop, wd, ln_eps, name=f'{name}_sub')
        self.axis_gate_dense = Dense(1, activation='sigmoid', name=f'{name}_axis_gate')
        self.chan_gate_dense = Dense(int(feat_dim), activation='sigmoid', name=f'{name}_chan_gate')
        self.bin_gate = BinaryGate(name=f'{name}_bin')
        self.out_stack = MLPStack(base_kn=out_hidden, depth=base_depth, drop=drop, wd=wd, ln_eps=ln_eps, name=f'{name}_out')
    def call(self, x, training=None, **kwargs):
        Bp = tf.shape(x)[0]; A = tf.shape(x)[1]; F = tf.shape(x)[2]
        # Channel (feature) gate computed from the mean over axes, broadcast over all A axes.
        x_mean_axis = tf.reduce_mean(x, axis=1); p_chan = self.chan_gate_dense(x_mean_axis)
        p_chan = tf.reshape(p_chan, [Bp,1,F]); g_chan = self.bin_gate(p_chan, training=training) if self.use_binary else p_chan
        x = x * g_chan
        # Shared per-axis MLP (borrowed from self.sub).
        x2 = tf.reshape(x,[Bp*A,F]); z = x2
        for lyr in self.sub.axis_mlp_layers:
            z = lyr(z, training=training) if isinstance(lyr, Dropout) else lyr(z)
        # Axis gate: one scalar per axis, broadcast over its axis_hidden features.
        z = tf.reshape(z,[Bp,A,self.sub.axis_hidden]); p_axis = self.axis_gate_dense(z)
        g_axis = self.bin_gate(p_axis, training=training) if self.use_binary else p_axis
        z = z * g_axis; z = tf.reshape(z,[Bp,A*self.sub.axis_hidden])
        z = self.out_stack(z, training=training); return z
    def compute_output_shape(self, input_shape):
        return tf.TensorShape([input_shape[0], self.out_stack.out_dim])

def _feat_dim_for_spec(use_time, use_freq, tag_dim):
    """Per-axis feature width of one block spec: enabled TSF statistics plus the axis-tag columns."""
    width = tag_dim
    if use_time:
        width += TIME_FEATS
    if use_freq:
        width += FREQ_FEATS
    return width

# Two TSF views per stream: 'short' frames the window into 4 blocks with time-domain
# statistics only; 'long' treats the whole window as one block with frequency-domain
# statistics only.
BLOCK_SPECS = [dict(name='short', num_blocks=4, use_time=True,  use_freq=False),
               dict(name='long',  num_blocks=1, use_time=False, use_freq=True)]

class RotationParamEstimator(Layer):
    """Predict 4 tanh-bounded rotation parameters per sample from an IMU window.

    call(): add ||acc|| / ||gyr|| channels (6 -> 8). Then, for each block spec,
    compute framed TSF features plus per-axis tags, apply a TimeDistributed
    TSFMixerBlock and flatten. Finally concatenate all specs -> MLPStack ->
    Dense(4, tanh). Multihead3DRotationOfficial reads the output as
    (axis_x, axis_y, axis_z, angle).
    Sublayer creation order defines the weight layout, so keep __init__ stable.
    """
    def __init__(self, block_specs, fs, mlp_base=128, mlp_depth=2, drop=0.5, wd=0.0, ln_eps=1e-7, use_binary=True, pad_mode='SYMMETRIC', name=None):
        super().__init__(name=name)
        self.block_specs=block_specs; self.fs=fs; self.mlp_base=int(mlp_base); self.mlp_depth=int(mlp_depth)
        self.drop=float(drop); self.wd=float(wd); self.ln_eps=float(ln_eps); self.use_binary=bool(use_binary); self.pad_mode=pad_mode
        # One tag row per channel of the 8-channel stream [acc x3, gyr x3, ||acc||, ||gyr||]:
        # column 0 = axis index 1..8; column 1 = sensor type (1 for channels 0-2 and 6, 2 for 3-5 and 7).
        axis_tags=[];
        for i in range(8):
            axis_type=i+1; sensor_type=1 if (i<=2 or i==6) else 2
            axis_tags.append([axis_type, sensor_type])
        axis_tags=np.array(axis_tags,dtype=np.float32)
        self.tag_spec={'axis_tags':axis_tags}; tag_dim=axis_tags.shape[1]
        self.extractors=[]; self.td_mixers=[]; self.flatteners=[]
        for spec in block_specs:
            ext=BlockTSFExtractor(num_blocks=spec['num_blocks'], fs=fs, use_time=spec['use_time'], use_freq=spec['use_freq'], tag_spec=self.tag_spec, pad_mode=self.pad_mode, name=f'rot_ext_{spec["name"]}')
            self.extractors.append(ext); feat_dim=_feat_dim_for_spec(spec['use_time'], spec['use_freq'], tag_dim)
            mix=TSFMixerBlock(feat_dim=feat_dim, axis_hidden=self.mlp_base, out_hidden=self.mlp_base, base_depth=max(1,self.mlp_depth-1), drop=self.drop, wd=self.wd, ln_eps=self.ln_eps, use_binary=self.use_binary, name=f'rot_mix_{spec["name"]}')
            # TimeDistributed applies the same mixer to every block (axis 1 of [B,K,C,F]).
            self.td_mixers.append(TimeDistributed(mix, name=f'rot_td_{spec["name"]}')); self.flatteners.append(Flatten(name=f'rot_flat_{spec["name"]}'))
        self.concat_sets=Concatenate(name='rot_concat_sets')
        self.post_stack=MLPStack(base_kn=self.mlp_base, depth=self.mlp_depth, drop=self.drop, wd=self.wd, ln_eps=self.ln_eps, name='rot_post')
        self.out_head=Dense(4, activation='tanh', name='rot4_tanh'); self.add_l2=AddL2Channels()
    def call(self, x, training=None, **kwargs):
        x8=self.add_l2(x); feats_all=[]
        for ext,td,flt in zip(self.extractors,self.td_mixers,self.flatteners):
            tsf_blocks=ext(x8, training=training); blk_feat=td(tsf_blocks, training=training); blk_feat=flt(blk_feat); feats_all.append(blk_feat)
        h=self.concat_sets(feats_all); h=self.post_stack(h, training=training); rot4=self.out_head(h); return rot4
    def compute_output_shape(self, input_shape): return tf.TensorShape([input_shape[0], 4])

class Multihead3DRotationOfficial(Layer):
    """Produce `head_nums` rotated copies of the acc/gyr triads.

    Each head asks the shared RotationParamEstimator for rot4, adds the
    previous head's rot4 (a cumulative sum), and builds a 3x3 rotation about
    axis rot4[:3] by angle rot4[3] * pi (Rodrigues' formula). It then rotates
    acc = x[..., 0:3] and gyr = x[..., 3:6] at every time step. Each output is
    [B, T, 6].

    NOTE(review): the estimator is called on the same `x` for every head. When it
    is deterministic (e.g. dropout inactive at inference), every head gets the
    same rot4, so head h rotates about the same axis by h times the first angle.
    NOTE(review): compute_output_shape reports the input shape per head, while
    call() always returns 6 channels.
    """
    def __init__(self, head_nums=2, fs=50.0, mlp_base=128, mlp_depth=2, drop=0.5, wd=0.0, ln_eps=1e-7, block_specs=None, use_binary=True, pad_mode='SYMMETRIC', name=None):
        super().__init__(name=name); block_specs = BLOCK_SPECS if block_specs is None else block_specs
        self.head_nums=int(head_nums)
        # A single estimator instance (shared weights) serves all heads.
        self.estimator=RotationParamEstimator(block_specs=block_specs, fs=fs, mlp_base=mlp_base, mlp_depth=mlp_depth, drop=drop, wd=wd, ln_eps=ln_eps, use_binary=use_binary, pad_mode=pad_mode, name='rot_estimator')
        self.eps=1e-8
    def compute_output_shape(self, input_shape): return [tf.TensorShape(input_shape) for _ in range(self.head_nums)]
    def _axis_angle_to_R(self, axis_raw, angle_raw):
        """Rodrigues' formula: R = cos(t) I + (1 - cos(t)) u u^T + sin(t) [u]_x, with t = angle_raw * pi.

        NOTE(review): if axis_raw is ~0, u is ~0 after the eps-guarded normalization
        and R degenerates to cos(t) * I, which is not a rotation.
        """
        axis=axis_raw/(tf.norm(axis_raw,axis=-1,keepdims=True)+self.eps); theta=angle_raw*math.pi
        # K is the cross-product (skew-symmetric) matrix of the unit axis u.
        B=tf.shape(axis)[0]; ux,uy,uz=axis[:,0],axis[:,1],axis[:,2]; z=tf.zeros_like(ux)
        K=tf.stack([z,-uz,uy,uz,z,-ux,-uy,ux,z],axis=-1); K=tf.reshape(K,[B,3,3])
        I=tf.tile(tf.eye(3,dtype=axis.dtype)[None,...],[B,1,1]); u=tf.expand_dims(axis,-1); uuT=tf.matmul(u,u,transpose_b=True)
        cos=tf.reshape(tf.cos(theta),[-1,1,1]); sin=tf.reshape(tf.sin(theta),[-1,1,1]); R=cos*I+(1.0-cos)*uuT+sin*K; return R
    def call(self, x, training=None, **kwargs):
        acc,gyr=x[:,:,:3],x[:,:,3:6]; out_list=[]; prev_rot4=None
        for _ in range(self.head_nums):
            rot4=self.estimator(x, training=training)
            # Cumulative parameters: each head builds on the previous head's rot4.
            if prev_rot4 is not None: rot4=rot4+prev_rot4
            prev_rot4=rot4; axis=rot4[:,:3]; angle=tf.expand_dims(rot4[:,3],-1); R=self._axis_angle_to_R(axis,angle)
            # Rotate each time step: [B,3,3] @ [B,3,T] -> back to [B,T,3].
            acc_t=tf.transpose(acc,[0,2,1]); acc_rot=tf.transpose(tf.matmul(R,acc_t),[0,2,1])
            gyr_t=tf.transpose(gyr,[0,2,1]); gyr_rot=tf.transpose(tf.matmul(R,gyr_t),[0,2,1])
            out_list.append(tf.concat([acc_rot,gyr_rot],axis=-1))
        return out_list

class AddL2ChannelsPublic(Layer):
    """Append ||x[..., 0:3]|| and ||x[..., 3:6]|| (acc and gyr norms) after all input channels.

    Same computation as AddL2Channels, used in the functional model graph.
    """
    def call(self, x, training=None):
        norms = [
            tf.sqrt(tf.reduce_sum(tf.square(x[:, :, lo:lo + 3]), axis=-1, keepdims=True))
            for lo in (0, 3)
        ]
        return tf.concat([x] + norms, axis=-1)

def r_tsf_net_official(x_shape, n_classes, learning_rate=LR, base_kn=MLP_BASE, depth=MLP_DEPTH, dropout_rate=DROPOUT,
                       imu_rot_heads=IMU_ROT_HEADS, fs=FS, use_orig_input=USE_ORIG_INPUT, use_binary_selection=USE_BINARY_SELECTION,
                       ln_eps=LN_EPS, pad_mode=PAD_MODE):
    """Build and compile the rTsfNet functional model.

    Args:
        x_shape: full data shape; x_shape[1:] is used as the per-sample input
            shape (presumably (T, 6) with channels acc x/y/z, gyr x/y/z; confirm
            against the loader).
        n_classes: number of softmax outputs.
        Remaining args are hyperparameters; see the module-level constants.

    Graph: rotated streams (+ optional original) each get ||acc|| / ||gyr||
    channels (8 per stream) and are concatenated. Per BLOCK_SPECS entry, framed
    TSF features + axis tags go through a TimeDistributed TSFMixerBlock and are
    flattened. All specs are concatenated -> MLPStack -> Dense -> float32 softmax.
    The model is compiled with Adam(amsgrad) and sparse categorical cross-entropy.
    Layer creation order determines layer names and the saved-weight layout,
    so keep it stable.
    """
    inputs=Input(shape=x_shape[1:]); x=inputs
    rot_layer=Multihead3DRotationOfficial(head_nums=imu_rot_heads, fs=fs, mlp_base=base_kn, mlp_depth=max(1,depth-1),
                                          drop=dropout_rate, wd=WEIGHT_DECAY, ln_eps=ln_eps, block_specs=BLOCK_SPECS,
                                          use_binary=use_binary_selection, pad_mode=pad_mode, name='multihead_rot_official')
    rotated_list=rot_layer(x)
    # One weight-free AddL2ChannelsPublic instance is shared by all streams.
    streams=[]; add_l2=AddL2ChannelsPublic()
    if use_orig_input: streams.append(add_l2(x))
    for xr in rotated_list: streams.append(add_l2(xr))
    concat_streams=Concatenate(axis=-1,name='concat_streams')(streams)
    # Axis tags for one 8-channel stream (same scheme as RotationParamEstimator), repeated per stream
    # so there is one tag row per channel of concat_streams.
    axis_tags_one_stream=[[i+1, 1 if (i<=2 or i==6) else 2] for i in range(8)]
    axis_tags_one_stream=np.array(axis_tags_one_stream,dtype=np.float32)
    num_streams=(1 if use_orig_input else 0)+imu_rot_heads
    axis_tags_all=np.concatenate([axis_tags_one_stream for _ in range(num_streams)],axis=0)
    tag_spec_main={'axis_tags':axis_tags_all}; tag_dim_main=axis_tags_all.shape[1]
    feats_all_sets=[]
    for spec in BLOCK_SPECS:
        ext=BlockTSFExtractor(num_blocks=spec['num_blocks'], fs=fs, use_time=spec['use_time'], use_freq=spec['use_freq'],
                              tag_spec=tag_spec_main, pad_mode=pad_mode, name=f'main_ext_{spec["name"]}')
        feat_dim=_feat_dim_for_spec(spec['use_time'], spec['use_freq'], tag_dim_main)
        mix=TSFMixerBlock(feat_dim=feat_dim, axis_hidden=base_kn, out_hidden=base_kn, base_depth=max(1,depth-1),
                          drop=dropout_rate, wd=WEIGHT_DECAY, ln_eps=ln_eps, use_binary=use_binary_selection, name=f'main_mix_{spec["name"]}')
        td=TimeDistributed(mix, name=f'main_td_{spec["name"]}'); flt=Flatten(name=f'main_flat_{spec["name"]}')
        tsf_blocks=ext(concat_streams); blk_feat=td(tsf_blocks); blk_feat=flt(blk_feat); feats_all_sets.append(blk_feat)
    z=Concatenate(name='main_concat_sets')(feats_all_sets)
    cls_stack=MLPStack(base_kn=base_kn, depth=depth, drop=dropout_rate, wd=WEIGHT_DECAY, ln_eps=ln_eps, name='cls')
    z=cls_stack(z); logits=Dense(n_classes, kernel_regularizer=l2(WEIGHT_DECAY), name='logits')(z)
    # Softmax forced to float32 (keeps probabilities float32 even under a mixed-precision policy).
    probs=Activation('softmax', dtype='float32', name='softmax')(logits)
    model=Model(inputs, probs, name='rTsfNet_official_aligned'); opt=Adam(learning_rate=learning_rate, amsgrad=True)
    model.compile(loss='sparse_categorical_crossentropy', optimizer=opt, metrics=['accuracy']); return model

# ---------------- Data: prefer /content/features; otherwise synthetic demo ----------------
# Colab layout: inputs under /content/features, weights under /content/models, outputs under /content/logs.
BASE = Path('/content')
features_dir = BASE / 'features'
models_dir = BASE / 'models'
logs_dir = BASE / 'logs'
models_dir.mkdir(parents=True, exist_ok=True)
logs_dir.mkdir(parents=True, exist_ok=True)

def load_fold_data(fold_k, features_dir: Path):
    """Load fold `fold_k` from features_dir/windows_normalized_fold{fold_k}.npz and split it.

    The six per-channel arrays (acc_x, acc_y, acc_z, gyro_x, gyro_y, gyro_z) are
    stacked on a new last axis. Presumably each is [N, T], giving X of shape
    [N, T, 6]; confirm against the feature-export step. Rows are selected by the
    `splits` array: 'train' or 'test'.

    Args:
        fold_k: fold index used in the file name.
        features_dir: directory containing the .npz files (str or Path).

    Returns:
        (X_train, y_train, X_test, y_test)
    """
    npz_file = Path(features_dir) / f'windows_normalized_fold{fold_k}.npz'
    channels = ('acc_x', 'acc_y', 'acc_z', 'gyro_x', 'gyro_y', 'gyro_z')
    # allow_pickle=True is kept because `splits` may be stored as an object (string) array.
    # NOTE(review): this can unpickle arbitrary objects; only load trusted feature files.
    # The context manager closes the NpzFile handle (previously left open).
    with np.load(npz_file, allow_pickle=True) as data:
        X = np.stack([data[c] for c in channels], axis=-1)
        y = data['labels']
        splits = data['splits']
    train_mask = splits == 'train'
    test_mask = splits == 'test'
    return X[train_mask], y[train_mask], X[test_mask], y[test_mask]

def scan_available_folds(features_dir: Path):
    """Return the sorted, de-duplicated fold indices found in `features_dir`.

    A fold k is detected from a file named windows_normalized_fold{k}.npz. Files
    whose suffix after the prefix is not an integer are skipped.
    """
    prefix = "windows_normalized_fold"
    folds = set()
    for f in Path(features_dir).glob(f"{prefix}*.npz"):
        try:
            folds.add(int(f.stem[len(prefix):]))
        except ValueError:
            # Not a numbered fold file (e.g. windows_normalized_fold_old.npz); ignore it.
            continue
    return sorted(folds)

def make_synth_fold(n_train=4000, n_test=800, T=150, C=6, n_classes=8, seed=2025):
    """Generate a random stand-in fold: standard-normal windows and uniform integer labels.

    Returns (X_train, y_train, X_test, y_test), with X float32 of shape [n, T, C]
    and y int64 in [0, n_classes). Values are drawn in a fixed order (train X,
    test X, train y, test y), so a given seed always yields the same data.
    """
    gen = np.random.default_rng(seed)

    def _signals(n):
        return gen.normal(0, 1, size=(n, T, C)).astype(np.float32)

    def _labels(n):
        return gen.integers(0, n_classes, size=n).astype(np.int64)

    X_train = _signals(n_train)
    X_test = _signals(n_test)
    y_train = _labels(n_train)
    y_test = _labels(n_test)
    return X_train, y_train, X_test, y_test

available_folds = scan_available_folds(features_dir)
# Measure only the first detected fold. [0] is a placeholder that makes the per-fold loop
# fall back to synthetic data when no feature files exist.
ACTIVE_FOLDS = available_folds[:1] or [0]
print(f"[Info] Available folds: {available_folds} | Planned measurement: {ACTIVE_FOLDS}")

# ---------------- TF inference wrapper (full test set = one logical call; unit = window) ----------------
# Best-effort: let TF allocate GPU memory on demand instead of reserving it all up front.
# Failures are deliberately ignored. Presumably this matters when the GPU runtime was
# already initialized (e.g. by an earlier notebook cell), in which case TF refuses the change.
for g in tf.config.list_physical_devices('GPU'):
    try: tf.config.experimental.set_memory_growth(g, True)
    except Exception: pass

def make_tf_runner(model: tf.keras.Model, X_test_np: np.ndarray, bs: int = 256):
    """Wrap inference over the whole test set as a zero-argument callable.

    The test windows are converted to float32 once, up front, under /GPU:0 when
    a GPU is visible (otherwise /CPU:0). The intent is to keep host-to-device
    copies out of the timed work. Each `run_once()` call runs the model over all
    windows in mini-batches of `bs` through a tf.function, then reads one scalar
    back to the host.

    Returns:
        (run_once, N): the callable and the number of windows it processes per call.

    Raises:
        ValueError: if X_test_np has no windows or bs < 1. Previously an empty test
            set only failed later, inside tf.reduce_sum(None).
    """
    N = X_test_np.shape[0]
    if N == 0:
        raise ValueError("X_test_np contains no windows; nothing to run inference on.")
    if bs < 1:
        raise ValueError(f"bs must be >= 1, got {bs}")
    device = "/GPU:0" if tf.config.list_physical_devices('GPU') else "/CPU:0"
    with tf.device(device):
        X_gpu = tf.convert_to_tensor(X_test_np.astype(np.float32))

    @tf.function(jit_compile=False)
    def fwd(x):
        return model(x, training=False)

    def run_once():
        last = None
        for s in range(0, N, bs):
            last = fwd(X_gpu[s:min(N, s + bs)])
        # .numpy() forces the final batch's result back to the host, so run_once() does not
        # return with work still queued. Earlier batches were enqueued before it; presumably
        # they run on the same device stream.
        _ = tf.reduce_sum(last).numpy()

    return run_once, N  # N windows per call

# ---------------- Idle power ----------------
# Baseline measured once for all folds. P_idle_mW is a module-level global that
# measure_with_bootstrap reads implicitly. Any other load on the GPU during these 20 s
# is counted as "idle" and subtracted from the active-phase energy.
print("\n[Info] Measuring idle power for 20 s ...")
P_idle_mW,_idle=sample_idle_power_mW(duration_s=20.0, dev_index=0, interval=0.02,
                                     save_csv=str(logs_dir/'power_idle_trace_rtsfnet_official.csv'))
print(f"[Info] Mean idle power ~ {P_idle_mW:.1f} mW")

# ---------------- Per-fold measurement (per-window metrics) ----------------
# For each active fold: load (or synthesize) data, build the model, optionally load weights,
# run a quick accuracy check, warm up, calibrate repeats, then measure energy/latency 5x.
summary_rows=[]
for k in ACTIVE_FOLDS:
    print("\n"+"="*72)
    print(f"[rTsfNet-Official] Fold {k} — preparing data and model (original architecture/hyperparameters)")

    # Training data is only used for the input shape and n_classes; the code shown does not train.
    if k in available_folds:
        X_train,y_train,X_test,y_test=load_fold_data(k, features_dir); n_classes=int(max(y_train.max(), y_test.max())+1)
    else:
        print("[Warn] Real feature files not found; using synthetic data for demonstration.")
        X_train,y_train,X_test,y_test=make_synth_fold(); n_classes=int(max(y_train.max(), y_test.max())+1)

    # Derive window_seconds from data length and FS (explicit reporting)
    window_seconds=float(X_test.shape[1]/FS)

    model=r_tsf_net_official(x_shape=X_train.shape, n_classes=n_classes,
                             learning_rate=LR, base_kn=MLP_BASE, depth=MLP_DEPTH, dropout_rate=DROPOUT,
                             imu_rot_heads=IMU_ROT_HEADS, fs=FS, use_orig_input=USE_ORIG_INPUT,
                             use_binary_selection=USE_BINARY_SELECTION, ln_eps=LN_EPS, pad_mode=PAD_MODE)

    # If no weight file exists (or loading fails) the model keeps its random initialization.
    # Energy/latency are still measured, but the accuracy check below is then meaningless.
    wpath=models_dir/f"model_fold{k}.weights.h5"
    if wpath.exists():
        try: model.load_weights(wpath); print(f"[Info] Loaded weights: {wpath.name}")
        except Exception as e: print(f"[Warn] Failed to load weights: {e}")

    # Optional sanity check (excluded from energy measurement)
    try:
        y_prob=model.predict(X_test, batch_size=256, verbose=0); acc=(y_prob.argmax(1)==y_test).mean()
        print(f"[Check] Fold {k} quick accuracy: {acc:.3f}")
    except Exception as e:
        print(f"[Warn] Skipping accuracy check: {e}")

    run_once,N_windows_per_call=make_tf_runner(model, X_test, bs=256)

    # Untimed warm-up passes (presumably to absorb tf.function tracing and allocator warm-up),
    # then size the measured span to roughly 8 s.
    for _ in range(3): run_once()
    repeats=calibrate_repeats(run_once, target_s=8.0, min_rep=3, max_rep=5000)
    print(f"[Info] repeats = {repeats}  (windows per call = {N_windows_per_call})")

    # measure_with_bootstrap subtracts the global P_idle_mW measured above.
    tag=f"rtsfnet_official_fold{k}_per_window"
    summ=measure_with_bootstrap(name=tag, run_once=run_once, n_windows=N_windows_per_call,
                                repeats=repeats, n_runs=5, n_boot=1000, logdir=logs_dir)

    summary_rows.append({
        "fold": k,
        "model": f"rTsfNet-Official (fold {k})",
        "window_seconds": window_seconds,
        "mJ_per_window_mean": summ["mean_mJ_per_window"],
        "ci95_low_mJ": summ["ci95_low_mJ"],
        "ci95_high_mJ": summ["ci95_high_mJ"],
        "ms_per_window_mean": summ["mean_ms_per_window"],
    })

    # Reset Keras global state and free memory before the next fold.
    K.clear_session(); gc.collect()

# ---------------- Summary output ----------------
# Write the per-fold summary table and list the produced files under the real logs_dir.
# Previously a hard-coded relative "logs/" prefix was printed, which is only correct
# when the working directory is /content.
df_sum=pd.DataFrame(summary_rows).sort_values("fold").reset_index(drop=True)
df_sum.to_csv(logs_dir/"energy_summary_rtsfnet_official_per_window.csv", index=False)
print("\n=== Completed (rTsfNet official × Option 1 · per-window energy & latency) ===")
print(df_sum)
print("\nLog files:")
for _log_pattern in ("power_idle_trace_rtsfnet_official.csv",
                     "power_trace_rtsfnet_official_fold*_per_window_run*.csv",
                     "energy_rtsfnet_official_fold*_per_window.json",
                     "energy_summary_rtsfnet_official_per_window.csv"):
    print(f"- {logs_dir / _log_pattern}")

Mon Nov 17 18:39:37 2025       
+-----------------------------------------------------------------------------------------+
| NVIDIA-SMI 550.54.15              Driver Version: 550.54.15      CUDA Version: 12.4     |
|-----------------------------------------+------------------------+----------------------+
| GPU  Name                 Persistence-M | Bus-Id          Disp.A | Volatile Uncorr. ECC |
| Fan  Temp   Perf          Pwr:Usage/Cap |           Memory-Usage | GPU-Util  Compute M. |
|                                         |                        |               MIG M. |
|   0  NVIDIA A100-SXM4-80GB          Off |   00000000:00:05.0 Off |                    0 |
| N/A   32C    P0             60W /  400W |   26271MiB /  81920MiB |      0%      Default |
|                                         |                        |             Disabled |
+-----------------------------------------+------------------------+----------------------+
                                                