In [6]:
# ============ Official rTsfNet (IMWUT 2024, TSF-Mixer + multi-head rotation) – structure & size ============

import json
import math
from pathlib import Path

import numpy as np
import tensorflow as tf
from tensorflow.keras import Input, Model
from tensorflow.keras.layers import (
    Dense, Dropout, LayerNormalization, LeakyReLU,
    Layer, Activation, TimeDistributed, Flatten, Concatenate,
    GlobalAveragePooling1D
)
from tensorflow.keras.optimizers import Adam
from tensorflow.keras.regularizers import l2

print("\n[Official rTsfNet (IMWUT 2024) – structure & size]")

# ---------------------------
# 1) Resolve class count and window length (config file first, defaults otherwise)
# ---------------------------
BASE = Path("/content")
CFG_DIR = BASE / "configs"

if not (CFG_DIR / "classes.json").exists():
    # No config on disk: fall back to the built-in defaults.
    NUM_CLASSES, WINDOW_SAMPLES = 8, 150
    print("Warning: /content/configs/classes.json not found.")
    print(f"Using defaults: NUM_CLASSES = {NUM_CLASSES}, WINDOW_SAMPLES = {WINDOW_SAMPLES}.")
else:
    with open(CFG_DIR / "classes.json", "r") as f:
        classes_cfg = json.load(f)
    # "num_classes" is mandatory in the config file.
    NUM_CLASSES = int(classes_cfg["num_classes"])
    # The window length is optional; 150 samples is used when it is absent.
    WINDOW_SAMPLES = (
        int(classes_cfg["window_config"]["window_samples"])
        if "window_config" in classes_cfg and "window_samples" in classes_cfg["window_config"]
        else 150
    )
    print(f"Detected NUM_CLASSES = {NUM_CLASSES}, WINDOW_SAMPLES = {WINDOW_SAMPLES} from configs.")

# ---------------------------
# 2) Hyperparameters (must match your official rTsfNet script)
# ---------------------------
# Sampling rate forwarded to the model builder as `fs` (presumably Hz — confirm).
FS                 = 50.0
# Number of rotated streams produced by the multi-head rotation layer.
IMU_ROT_HEADS      = 2
# Base width and depth of the MLP stacks (passed as base_kn / depth).
MLP_BASE           = 128
MLP_DEPTH          = 3
# Dropout rate used inside every MLP stack.
DROPOUT            = 0.5
# Learning rate for the Adam optimizer.
LR                 = 1e-3
# L2 kernel-regularisation factor applied to the regularised Dense layers.
WEIGHT_DECAY       = 1e-6
# If True, the unrotated input is kept as an additional stream next to the rotated ones.
USE_ORIG_INPUT     = True
# If True, the mixer gates go through BinaryGate; otherwise the raw sigmoid outputs are used.
USE_BINARY_SELECTION = True
# Epsilon for every LayerNormalization.
LN_EPS             = 1e-7
# Padding mode handed to tf.pad when a window is framed into blocks.
PAD_MODE           = "SYMMETRIC"

# One entry per block set: how many blocks the window is split into and which
# feature groups (time domain / frequency domain) are extracted from each block.
BLOCK_SPECS = [
    dict(name="short", num_blocks=4, use_time=True,  use_freq=False),
    dict(name="long",  num_blocks=1, use_time=False, use_freq=True),
]

# Feature counts per axis; they must match the number of tensors that
# TSFFeatureLayer.call appends for each group.
TIME_FEATS = 12  # mean/std/max/min/ptp/rms/energy/skew/kurt/zcr/ar1/ar2
FREQ_FEATS = 7   # centroid/entropy/flatness/soft-peak/bandpowers(3)

# Echo the effective configuration (LR is not part of this printout).
print(f"\nConfig for size check:")
print(f"  NUM_CLASSES         = {NUM_CLASSES}")
print(f"  WINDOW_SAMPLES      = {WINDOW_SAMPLES}")
print(f"  IMU_ROT_HEADS       = {IMU_ROT_HEADS}")
print(f"  MLP_BASE            = {MLP_BASE}")
print(f"  MLP_DEPTH           = {MLP_DEPTH}")
print(f"  DROPOUT             = {DROPOUT}")
print(f"  WEIGHT_DECAY        = {WEIGHT_DECAY}")
print(f"  USE_ORIG_INPUT      = {USE_ORIG_INPUT}")
print(f"  USE_BINARY_SELECTION= {USE_BINARY_SELECTION}")
print(f"  PAD_MODE            = {PAD_MODE}")

# ---------------------------
# 3) Utility MLP stack (shared)
# ---------------------------
class MLPStack(Layer):
    """
    Tapering MLP made of `depth` stages of Dense -> LayerNorm -> LeakyReLU -> Dropout.

    Stage widths are base_kn * 2**k for k = depth-1, ..., 0, so the last stage
    (and therefore the output) is base_kn units wide.
    """
    def __init__(self, base_kn=128, depth=3, drop=0.5, wd=0.0, ln_eps=1e-7, name=None):
        super().__init__(name=name)
        self.base_kn = int(base_kn)
        self.depth = int(depth)
        self.drop = float(drop)
        self.wd = float(wd)
        self.ln_eps = float(ln_eps)

        # Flat, ordered list of sub-layers; the widest stage comes first.
        stages = []
        for k in reversed(range(self.depth)):
            width = self.base_kn * (2**k)
            stages += [
                Dense(width, kernel_regularizer=l2(self.wd)),
                LayerNormalization(epsilon=self.ln_eps),
                LeakyReLU(),
                Dropout(self.drop),
            ]
        self.seq = stages

    @property
    def out_dim(self):
        """Width of the final Dense stage."""
        return self.base_kn

    def call(self, x, training=None):
        out = x
        for layer in self.seq:
            # Only the Dropout layers are handed the training flag.
            extra = {'training': training} if isinstance(layer, Dropout) else {}
            out = layer(out, **extra)
        return out

    def compute_output_shape(self, input_shape):
        batch = input_shape[0]
        return tf.TensorShape([batch, self.out_dim])

# ---------------------------
# 4) TSF feature layer (axis-wise)
# ---------------------------
class TSFFeatureLayer(Layer):
    """
    Axis-wise time-series features (TSF) for a batch of blocks.

    Input:  x with shape [B, L, C] (L samples per block, C axes).
    Output: [B, C, F], where F = TIME_FEATS (if use_time) + FREQ_FEATS (if use_freq);
            time-domain features come first.

    Time features: mean, std, max, min, peak-to-peak, rms, energy, skewness,
                   kurtosis, zero-crossing rate, lag-1 and lag-2 correlation ratios.
    Freq features: spectral centroid, normalised spectral entropy, spectral
                   flatness, softmax-weighted peak frequency, and the power
                   fractions in [0.5, 3), [3, 8) and [8, 15) (same unit as fs).

    At least one of use_time / use_freq has to be True, otherwise there is
    nothing to concatenate in call().
    """
    def __init__(self, fs=50.0, use_time=True, use_freq=True, **kwargs):
        super().__init__(**kwargs)
        self.fs = float(fs)
        self.use_time = bool(use_time)
        self.use_freq = bool(use_freq)
        self.eps = 1e-8  # guards divisions and logarithms
        self._feat_dim = (TIME_FEATS if self.use_time else 0) + (FREQ_FEATS if self.use_freq else 0)

    def get_config(self):
        cfg = super().get_config()
        cfg.update({'fs': self.fs, 'use_time': self.use_time, 'use_freq': self.use_freq})
        return cfg

    def call(self, x):  # x: [B, L, C]
        feats = []
        # Per-axis mean over the block; shared by both feature groups.
        mean = tf.reduce_mean(x, axis=1, keepdims=True)

        if self.use_time:
            std  = tf.math.reduce_std(x, axis=1, keepdims=True) + self.eps
            maxv = tf.reduce_max(x, axis=1, keepdims=True)
            minv = tf.reduce_min(x, axis=1, keepdims=True)
            ptp  = maxv - minv
            rms  = tf.sqrt(tf.reduce_mean(tf.square(x), axis=1, keepdims=True))
            energy = tf.reduce_sum(tf.square(x), axis=1, keepdims=True)

            # Standardised signal, computed once for the 3rd and 4th moments.
            z = (x - mean) / std
            skew = tf.reduce_mean(tf.pow(z, 3), axis=1, keepdims=True)
            kurt = tf.reduce_mean(tf.pow(z, 4), axis=1, keepdims=True)

            # |sign(x[t+1]) - sign(x[t])| is 2 for a full sign flip, hence the / 2.
            signs = tf.sign(x)
            sign_changes = tf.abs(signs[:, 1:, :] - signs[:, :-1, :])
            zcr = tf.reduce_mean(sign_changes, axis=1, keepdims=True) / 2.0

            # Lag-k ratio: sum(x[t] * x[t+k]) / sum(x[t]^2).
            x_t1 = x[:, :-1, :]; x_tn1 = x[:, 1:, :]
            ar1 = tf.reduce_sum(x_t1 * x_tn1, axis=1, keepdims=True) / (
                tf.reduce_sum(tf.square(x_t1), axis=1, keepdims=True) + self.eps
            )
            x_t2 = x[:, :-2, :]; x_tn2 = x[:, 2:, :]
            ar2 = tf.reduce_sum(x_t2 * x_tn2, axis=1, keepdims=True) / (
                tf.reduce_sum(tf.square(x_t2), axis=1, keepdims=True) + self.eps
            )
            feats += [mean, std, maxv, minv, ptp, rms, energy, skew, kurt, zcr, ar1, ar2]

        if self.use_freq:
            xc = x - mean                                 # remove the DC offset
            x_bc_t = tf.transpose(xc, [0, 2, 1])          # [B, C, L]
            fft = tf.signal.rfft(x_bc_t)                  # [B, C, n_bins]
            power = tf.square(tf.abs(fft)) + self.eps     # [B, C, n_bins]
            power = tf.transpose(power, [0, 2, 1])        # [B, n_bins, C]

            # rfft bin k sits at k * fs / L. A linspace up to fs / 2 is only
            # right for even L; for odd L the last bin is below Nyquist, so the
            # bin frequencies are derived from the block length instead.
            n_bins = tf.shape(power)[1]
            n_samples = tf.cast(tf.shape(x)[1], tf.float32)
            bin_idx = tf.cast(tf.range(n_bins), tf.float32)
            freqs = bin_idx * tf.cast(self.fs, tf.float32) / n_samples
            freqs = tf.reshape(freqs, [1, n_bins, 1])     # [1, n_bins, 1]

            # Total power per axis, reused by the distribution and the band powers.
            total_power = tf.reduce_sum(power, axis=1, keepdims=True) + self.eps

            p = power / total_power
            centroid = tf.reduce_sum(p * freqs, axis=1, keepdims=True)
            entropy  = -tf.reduce_sum(p * tf.math.log(p + self.eps), axis=1, keepdims=True) / (
                tf.math.log(tf.cast(n_bins, tf.float32) + self.eps)
            )
            # Flatness = geometric mean / arithmetic mean of the power spectrum.
            geo = tf.exp(tf.reduce_mean(tf.math.log(power), axis=1, keepdims=True))
            ari = tf.reduce_mean(power, axis=1, keepdims=True)
            flatness = geo / (ari + self.eps)
            # NOTE(review): the softmax runs on unnormalised power, so it is
            # presumably close to a hard arg-max — confirm this is intended.
            temp = 10.0
            w = tf.nn.softmax(power * temp, axis=1)
            soft_peak = tf.reduce_sum(w * freqs, axis=1, keepdims=True)

            def band(low, high):
                # Fraction of the total power inside [low, high).
                mask = tf.cast((freqs >= low) & (freqs < high), tf.float32)
                return tf.reduce_sum(power * mask, axis=1, keepdims=True) / total_power

            bp1 = band(0.5, 3.0); bp2 = band(3.0, 8.0); bp3 = band(8.0, 15.0)
            feats += [centroid, entropy, flatness, soft_peak, bp1, bp2, bp3]

        res = tf.concat(feats, axis=1)                    # [B, Fnum, C]
        return tf.transpose(res, [0, 2, 1])               # [B, C, Fnum]

    def compute_output_shape(self, input_shape):
        return tf.TensorShape([input_shape[0], input_shape[2], self._feat_dim])

# ---------------------------
# 5) L2 channels layers
# ---------------------------
class AddL2Channels(Layer):
    """
    Append two magnitude channels to x: [B, T, C] -> [B, T, C + 2].

    The first extra channel is the Euclidean norm of channels 0-2, the second
    the norm of channels 3-5 (presumably accelerometer and gyroscope triplets —
    confirm against the data loader). For the 6-channel input used in this
    file the result has 8 channels.
    """
    def call(self, x, training=None):
        acc = x[:, :, :3]
        gyr = x[:, :, 3:6]
        l2_acc = tf.sqrt(tf.reduce_sum(tf.square(acc), axis=-1, keepdims=True))
        l2_gyr = tf.sqrt(tf.reduce_sum(tf.square(gyr), axis=-1, keepdims=True))
        return tf.concat([x, l2_acc, l2_gyr], axis=-1)  # [B, T, C + 2]

    def compute_output_shape(self, input_shape):
        # call() keeps every input channel and adds two, so report C + 2
        # instead of a fixed 8 (identical for the 6-channel case).
        channels = input_shape[2]
        out_channels = None if channels is None else int(channels) + 2
        return tf.TensorShape([input_shape[0], input_shape[1], out_channels])


class AddL2ChannelsPublic(Layer):
    """
    Append the Euclidean norms of channels 0-2 and 3-5 as two extra channels.

    No compute_output_shape is defined for this variant.
    """
    def call(self, x, training=None):
        def _magnitude(triplet):
            # Norm over the last axis, kept as a trailing singleton channel.
            return tf.sqrt(tf.reduce_sum(tf.square(triplet), axis=-1, keepdims=True))

        first_norm = _magnitude(x[:, :, :3])
        second_norm = _magnitude(x[:, :, 3:6])
        return tf.concat([x, first_norm, second_norm], axis=-1)  # [B, T, 8] for 6-channel input

# ---------------------------
# 6) Block framing (Keras 3 safe)
# ---------------------------
def _int_ceil_div(a, b):
    """Ceiling division on int32 tensors, computed as floor((a + b - 1) / b)."""
    numer = tf.cast(a, tf.int32)
    denom = tf.cast(b, tf.int32)
    return tf.math.floordiv(numer + denom - 1, denom)

def frame_signal_with_padding(x, num_blocks, pad_mode='SYMMETRIC'):
    """
    Split x: [B, T, C] into `num_blocks` equal-length blocks along the time axis.

    The time axis is first padded with tf.pad (mode `pad_mode`) up to
    L * num_blocks samples, L = ceil(T / num_blocks); floor(pad / 2) samples go
    in front and the remainder behind. The padded signal is then reshaped to
    [B, num_blocks, L, C].
    """
    dyn_shape = tf.shape(x)
    batch, steps, chans = dyn_shape[0], dyn_shape[1], dyn_shape[2]
    n_blocks = tf.cast(num_blocks, tf.int32)
    block_len = _int_ceil_div(steps, n_blocks)

    # Samples missing to reach a whole number of blocks, split front/back.
    deficit = block_len * n_blocks - steps
    front = tf.math.floordiv(deficit, 2)
    back = deficit - front

    # Only the time axis is padded; batch and channel axes get [0, 0].
    no_pad = tf.constant([0, 0], dtype=tf.int32)
    paddings = tf.stack([no_pad, tf.stack([front, back]), no_pad], axis=0)

    padded = tf.pad(x, paddings, mode=pad_mode)
    return tf.reshape(padded, [batch, n_blocks, block_len, chans])

# ---------------------------
# 7) TSF block extractor
# ---------------------------
def _feat_dim_for_spec(use_time, use_freq, tag_dim):
    """Per-axis feature width: the enabled TSF groups plus `tag_dim` tag columns."""
    n_time = TIME_FEATS if use_time else 0
    n_freq = FREQ_FEATS if use_freq else 0
    return (n_time + n_freq) + tag_dim

class BlockTSFExtractor(Layer):
    """
    Apply TSF extraction and axis-tag injection for a block set.

    Input:  x with shape [B, T, C]
    Output: TSF tensor [B, num_blocks, A, F_total] (A = C; F_total includes tags).

    Args:
        num_blocks: number of equal-length blocks the window is framed into.
        fs, use_time, use_freq: forwarded to the inner TSFFeatureLayer.
        tag_spec: optional dict; if it has an 'axis_tags' entry of shape
            [A, tag_dim] (array or nested list), those columns are appended
            to every axis' feature vector.
        pad_mode: tf.pad mode used when framing.
    """
    def __init__(self, num_blocks, fs, use_time, use_freq,
                 tag_spec=None, pad_mode='SYMMETRIC', name=None, **kwargs):
        super().__init__(name=name, **kwargs)
        self.num_blocks = int(num_blocks)
        self.tsf = TSFFeatureLayer(fs=fs, use_time=use_time, use_freq=use_freq)
        self.pad_mode = pad_mode

        if tag_spec is not None and 'axis_tags' in tag_spec:
            # Normalise to an ndarray so that nested lists (e.g. coming back
            # from get_config / JSON) are accepted as well.
            tag_spec = dict(tag_spec)
            tag_spec['axis_tags'] = np.asarray(tag_spec['axis_tags'])
            self.tag_dim = int(tag_spec['axis_tags'].shape[1])
        else:
            self.tag_dim = 0
        self.tag_spec = tag_spec

        self.base_feat_dim = (TIME_FEATS if use_time else 0) + (FREQ_FEATS if use_freq else 0)
        self.out_feat_dim = self.base_feat_dim + self.tag_dim

    def get_config(self):
        cfg = super().get_config()
        # The axis tags are part of the layer definition (they change the
        # output width), so they are serialised as plain nested lists.
        tag_spec_cfg = None
        if self.tag_dim > 0:
            tag_spec_cfg = {'axis_tags': np.asarray(self.tag_spec['axis_tags']).tolist()}
        cfg.update({'num_blocks': self.num_blocks, 'fs': self.tsf.fs,
                    'use_time': self.tsf.use_time, 'use_freq': self.tsf.use_freq,
                    'tag_spec': tag_spec_cfg,
                    'pad_mode': self.pad_mode})
        return cfg

    def call(self, x, training=None):  # x: [B, T, C]
        xb = frame_signal_with_padding(x, self.num_blocks, pad_mode=self.pad_mode)  # [B, K, L, C]
        B = tf.shape(xb)[0]; K = tf.shape(xb)[1]; L = tf.shape(xb)[2]; C = tf.shape(xb)[3]
        # Fold the block axis into the batch so the feature layer sees [B*K, L, C].
        xb2 = tf.reshape(xb, [B * K, L, C])                        # [B*K, L, C]
        tsf_axis = self.tsf(xb2)                                   # [B*K, C, F]
        tsf_axis = tf.reshape(tsf_axis, [B, K, C, self.base_feat_dim])  # [B, K, A, F_base]

        if self.tag_dim > 0:
            # Broadcast the constant per-axis tags over batch and blocks.
            axis_tags = tf.convert_to_tensor(self.tag_spec['axis_tags'], dtype=tsf_axis.dtype)  # [A, tag_dim]
            axis_tags = tf.reshape(axis_tags, [1, 1, tf.shape(tsf_axis)[2], -1])               # [1,1,A,tag_dim]
            axis_tags = tf.tile(axis_tags, [B, K, 1, 1])                                       # [B,K,A,tag_dim]
            tsf_axis = tf.concat([tsf_axis, axis_tags], axis=-1)                               # [B,K,A,F_base+tag_dim]
        return tsf_axis

    def compute_output_shape(self, input_shape):
        return tf.TensorShape([input_shape[0], self.num_blocks, input_shape[2], self.out_feat_dim])

# ---------------------------
# 8) Binary gate (STE)
# ---------------------------
class BinaryGate(Layer):
    """
    Hard 0/1 gate with a straight-through estimator (STE).

    Forward pass:  round(clip(p, 0, 1)), i.e. a binary mask (tf.round rounds
                   halves to the nearest even value, so exactly 0.5 maps to 0).
    Backward pass: the rounding is treated as identity, so the gradient
                   reaches `p` unchanged (inside the clip range).
    """
    def call(self, p, training=None):
        p = tf.clip_by_value(p, 0.0, 1.0)
        hard = tf.round(p)
        # p + stop_gradient(hard - p) evaluates to `hard`, while only the
        # leading `p` term is differentiated. The reversed form
        # (hard + stop_gradient(p - hard)) would output the soft value and
        # pass no gradient, because tf.round has none.
        return p + tf.stop_gradient(hard - p)

    def compute_output_shape(self, input_shape):
        return tf.TensorShape(input_shape)

# ---------------------------
# 9) TSF-Mixer sub-block and block
# ---------------------------
class TSFMixerSubBlock(Layer):
    """
    Mix per-axis TSF features [B', A, F] into one block feature [B', out_hidden].

    Every axis runs through the same (shared-weight) MLP; the per-axis results
    are laid side by side and reduced by a second MLP stack.
    """
    def __init__(self, axis_hidden=128, out_hidden=128, base_depth=2,
                 drop=0.5, wd=0.0, ln_eps=1e-7, name=None):
        super().__init__(name=name)
        self.axis_hidden = int(axis_hidden)
        self.out_hidden = int(out_hidden)
        self.base_depth = int(base_depth)
        self.drop = float(drop)
        self.wd = float(wd)
        self.ln_eps = float(ln_eps)

        # Axis-shared MLP: widths axis_hidden * 2**k for k = base_depth-1, ..., 0.
        shared_layers = []
        for k in reversed(range(self.base_depth)):
            shared_layers += [
                Dense(self.axis_hidden * (2**k), kernel_regularizer=l2(self.wd)),
                LayerNormalization(epsilon=self.ln_eps),
                LeakyReLU(),
                Dropout(self.drop),
            ]
        self.axis_mlp_layers = shared_layers

        self.out_stack = MLPStack(base_kn=self.out_hidden, depth=self.base_depth,
                                  drop=self.drop, wd=self.wd, ln_eps=self.ln_eps,
                                  name=f'{self.name}_out')

    def call(self, x, training=None, **kwargs):  # x: [B', A, F]
        dyn = tf.shape(x)
        n_items, n_axes, n_feats = dyn[0], dyn[1], dyn[2]

        # Treat every axis as its own sample for the shared MLP.
        h = tf.reshape(x, [n_items * n_axes, n_feats])
        for layer in self.axis_mlp_layers:
            h = layer(h, training=training) if isinstance(layer, Dropout) else layer(h)

        h = tf.reshape(h, [n_items, n_axes, self.axis_hidden])     # [B', A, H_axis]
        h = tf.reshape(h, [n_items, n_axes * self.axis_hidden])    # [B', A*H_axis]
        return self.out_stack(h, training=training)                # [B', H_out]

    def compute_output_shape(self, input_shape):
        return tf.TensorShape([input_shape[0], self.out_stack.out_dim])


class TSFMixerBlock(Layer):
    """
    TSF-Mixer block: the sub-block's axis-shared MLP plus two selection gates.

      - channel-wise selection over the feature dim (gate computed from the
        axis-averaged features, applied before the axis MLP)
      - axis-wise selection over the axis dim (gate computed from the axis
        MLP output, applied before the axes are concatenated)

    With use_binary=True both gates go through BinaryGate; otherwise the raw
    sigmoid outputs are used as soft gates.

    Input:  [B', A, F] with F == feat_dim.  Output: [B', out_hidden].
    """
    def __init__(self, feat_dim, axis_hidden=128, out_hidden=128, base_depth=2,
                 drop=0.5, wd=0.0, ln_eps=1e-7, use_binary=True, name=None):
        super().__init__(name=name)
        self.use_binary = bool(use_binary)

        # Derive child names from the resolved layer name so that name=None
        # does not produce "None_sub", "None_out", ... (and duplicates across
        # instances). For an explicit name the result is unchanged.
        prefix = self.name

        # Only `self.sub.axis_mlp_layers` is used in call(); the sub-block's
        # own out_stack is not called from here.
        self.sub = TSFMixerSubBlock(axis_hidden, out_hidden, base_depth, drop, wd, ln_eps,
                                    name=f'{prefix}_sub')

        self.axis_gate_dense = Dense(1, activation='sigmoid', name=f'{prefix}_axis_gate')
        self.chan_gate_dense = Dense(int(feat_dim), activation='sigmoid', name=f'{prefix}_chan_gate')
        self.bin_gate = BinaryGate(name=f'{prefix}_bin')
        self.out_stack = MLPStack(base_kn=out_hidden, depth=base_depth,
                                  drop=drop, wd=wd, ln_eps=ln_eps, name=f'{prefix}_out')

    def call(self, x, training=None, **kwargs):  # x: [B', A, F]
        Bp = tf.shape(x)[0]; A = tf.shape(x)[1]; F = tf.shape(x)[2]

        # Channel gate: one value per feature, shared by all axes.
        x_mean_axis = tf.reduce_mean(x, axis=1)         # [B', F]
        p_chan = self.chan_gate_dense(x_mean_axis)      # [B', F]
        p_chan = tf.reshape(p_chan, [Bp, 1, F])
        g_chan = self.bin_gate(p_chan, training=training) if self.use_binary else p_chan
        x = x * g_chan

        # Axis-shared MLP: every axis is processed as an independent sample.
        x2 = tf.reshape(x, [Bp * A, F])
        z = x2
        for lyr in self.sub.axis_mlp_layers:
            if isinstance(lyr, Dropout):
                z = lyr(z, training=training)
            else:
                z = lyr(z)
        z = tf.reshape(z, [Bp, A, self.sub.axis_hidden])   # [B', A, H_axis]

        # Axis gate: one value per axis, broadcast over the hidden dim.
        p_axis = self.axis_gate_dense(z)                   # [B', A, 1]
        g_axis = self.bin_gate(p_axis, training=training) if self.use_binary else p_axis
        z = z * g_axis

        z = tf.reshape(z, [Bp, A * self.sub.axis_hidden])  # [B', A*H_axis]
        z = self.out_stack(z, training=training)           # [B', H_out]
        return z

    def compute_output_shape(self, input_shape):
        return tf.TensorShape([input_shape[0], self.out_stack.out_dim])

# ---------------------------
# 10) Rotation parameter estimator
# ---------------------------
class RotationParamEstimator(Layer):
    """
    Estimate four rotation parameters per window.

    Input:  [B, T, 6]. Two L2 channels are appended internally ([B, T, 8]);
            for every block set TSF features are extracted and mixed by a
            time-distributed TSFMixerBlock; the flattened results of all sets
            are concatenated and passed through an MLP stack.
    Output: [B, 4] from a Dense(4, tanh) head.
    """
    def __init__(self, block_specs, fs, mlp_base=128, mlp_depth=2,
                 drop=0.5, wd=0.0, ln_eps=1e-7,
                 use_binary=True, pad_mode='SYMMETRIC', name=None):
        super().__init__(name=name)
        self.block_specs = block_specs
        self.fs = fs
        self.mlp_base = int(mlp_base)
        self.mlp_depth = int(mlp_depth)
        self.drop = float(drop)
        self.wd = float(wd)
        self.ln_eps = float(ln_eps)
        self.use_binary = bool(use_binary)
        self.pad_mode = pad_mode

        # One [axis_type, sensor_type] row per channel of the 8-channel signal:
        # axis_type is the 1-based channel index; sensor_type is 1 for
        # channels 0-2 and 6, and 2 for channels 3-5 and 7.
        tag_rows = [[idx + 1, 1 if (idx <= 2 or idx == 6) else 2] for idx in range(8)]
        axis_tags = np.array(tag_rows, dtype=np.float32)
        self.tag_spec = {'axis_tags': axis_tags}
        tag_dim = axis_tags.shape[1]

        # Parallel lists, one entry per block set.
        self.extractors = []
        self.td_mixers  = []
        self.flatteners = []
        mixer_depth = max(1, self.mlp_depth - 1)
        for spec in block_specs:
            set_name = spec["name"]
            self.extractors.append(
                BlockTSFExtractor(num_blocks=spec['num_blocks'], fs=fs,
                                  use_time=spec['use_time'], use_freq=spec['use_freq'],
                                  tag_spec=self.tag_spec, pad_mode=self.pad_mode,
                                  name=f'rot_ext_{set_name}')
            )
            mixer = TSFMixerBlock(
                feat_dim=_feat_dim_for_spec(spec['use_time'], spec['use_freq'], tag_dim),
                axis_hidden=self.mlp_base,
                out_hidden=self.mlp_base,
                base_depth=mixer_depth,
                drop=self.drop, wd=self.wd,
                ln_eps=self.ln_eps, use_binary=self.use_binary,
                name=f'rot_mix_{set_name}')
            self.td_mixers.append(TimeDistributed(mixer, name=f'rot_td_{set_name}'))
            self.flatteners.append(Flatten(name=f'rot_flat_{set_name}'))

        self.concat_sets = Concatenate(name='rot_concat_sets')
        self.post_stack = MLPStack(base_kn=self.mlp_base, depth=self.mlp_depth,
                                   drop=self.drop, wd=self.wd, ln_eps=self.ln_eps,
                                   name='rot_post')
        self.out_head = Dense(4, activation='tanh', name='rot4_tanh')
        self.add_l2 = AddL2Channels()

    def call(self, x, training=None, **kwargs):  # x: [B, T, 6]
        signal8 = self.add_l2(x)  # [B, T, 8]

        per_set = []
        for extractor, mixer_td, flatten in zip(self.extractors, self.td_mixers, self.flatteners):
            blocks = extractor(signal8, training=training)   # [B, K, A, F]
            mixed = mixer_td(blocks, training=training)      # [B, K, H]
            per_set.append(flatten(mixed))                   # [B, K*H]

        merged = self.concat_sets(per_set)                   # [B, sum(K*H)]
        hidden = self.post_stack(merged, training=training)
        return self.out_head(hidden)                         # [B, 4]

    def compute_output_shape(self, input_shape):
        return tf.TensorShape([input_shape[0], 4])

# ---------------------------
# 11) Multi-head 3D rotation (official)
# ---------------------------
class Multihead3DRotationOfficial(Layer):
    """
    Multi-head 3D rotation of a 6-channel stream.

    Input:  [B, T, 6] (channels 0-2 and 3-5 are rotated as two 3-vectors).
    Output: list with `head_nums` rotated streams, each [B, T, 6].

    For every head the RotationParamEstimator output (3 axis components and
    one angle) is added to the running sum of the previous heads, and that sum
    defines the head's rotation.

    NOTE(review): one estimator instance is applied to the same input for all
    heads, so the heads differ only through the running sum (and dropout while
    training) — confirm this matches the reference implementation.
    """
    def __init__(self, head_nums=2, fs=50.0, mlp_base=128, mlp_depth=2,
                 drop=0.5, wd=0.0, ln_eps=1e-7,
                 block_specs=None, use_binary=True, pad_mode='SYMMETRIC', name=None):
        super().__init__(name=name)
        if block_specs is None:
            block_specs = BLOCK_SPECS
        self.head_nums = int(head_nums)
        self.estimator = RotationParamEstimator(block_specs=block_specs, fs=fs,
                                                mlp_base=mlp_base, mlp_depth=mlp_depth,
                                                drop=drop, wd=wd,
                                                ln_eps=ln_eps, use_binary=use_binary,
                                                pad_mode=pad_mode,
                                                name='rot_estimator')
        self.eps = 1e-8

    def compute_output_shape(self, input_shape):
        # Every head returns a tensor shaped like the input.
        return [tf.TensorShape(input_shape) for _head in range(self.head_nums)]

    def _axis_angle_to_R(self, axis_raw, angle_raw):
        """
        Build [B, 3, 3] matrices from an axis [B, 3] and an angle [B, 1] using
        R = cos(t) * I + (1 - cos(t)) * u u^T + sin(t) * [u]_x,
        where u is the normalised axis and t = angle_raw * pi.
        """
        unit = axis_raw / (tf.norm(axis_raw, axis=-1, keepdims=True) + self.eps)
        theta = angle_raw * math.pi
        n = tf.shape(unit)[0]

        # Cross-product (skew-symmetric) matrix of the unit axis.
        ux, uy, uz = unit[:, 0], unit[:, 1], unit[:, 2]
        zero = tf.zeros_like(ux)
        skew = tf.stack([zero, -uz,  uy,
                         uz,  zero, -ux,
                        -uy,  ux,  zero], axis=-1)
        skew = tf.reshape(skew, [n, 3, 3])

        identity = tf.tile(tf.eye(3, dtype=unit.dtype)[None, ...], [n, 1, 1])
        column = tf.expand_dims(unit, -1)
        outer = tf.matmul(column, column, transpose_b=True)

        cos_t = tf.reshape(tf.cos(theta), [-1, 1, 1])
        sin_t = tf.reshape(tf.sin(theta), [-1, 1, 1])
        return cos_t * identity + (1.0 - cos_t) * outer + sin_t * skew

    def call(self, x, training=None, **kwargs):  # x: [B, T, 6]
        acc, gyr = x[:, :, :3], x[:, :, 3:6]

        def _rotate(rot, vec3):
            # Apply rot [B, 3, 3] to every time step of vec3 [B, T, 3].
            as_columns = tf.transpose(vec3, [0, 2, 1])
            return tf.transpose(tf.matmul(rot, as_columns), [0, 2, 1])

        rotated_streams = []
        running = None  # accumulated parameters of the heads handled so far
        for _head in range(self.head_nums):
            params = self.estimator(x, training=training)  # [B, 4]
            if running is not None:
                params = params + running
            running = params

            rot = self._axis_angle_to_R(params[:, :3],
                                        tf.expand_dims(params[:, 3], -1))  # [B, 3, 3]
            acc_rot = _rotate(rot, acc)
            gyr_rot = _rotate(rot, gyr)
            rotated_streams.append(tf.concat([acc_rot, gyr_rot], axis=-1))  # [B, T, 6]
        return rotated_streams

# ---------------------------
# 12) Official rTsfNet body
# ---------------------------
def r_tsf_net_official(x_shape, n_classes,
                       learning_rate=1e-3, base_kn=128, depth=3, dropout_rate=0.5,
                       imu_rot_heads=2, fs=50.0, use_orig_input=True,
                       use_binary_selection=True, ln_eps=1e-7, pad_mode='SYMMETRIC',
                       weight_decay=None, block_specs=None):
    """
    Build and compile the rTsfNet classifier.

    Pipeline: multi-head rotation -> append L2 channels to every stream ->
    concatenate streams on the channel axis -> per block set: TSF extraction
    and time-distributed TSF-Mixer -> concatenate sets -> MLP stack -> softmax.

    Args:
        x_shape: full input shape including the batch dim; x_shape[1:] is used.
        n_classes: number of output classes.
        learning_rate: Adam learning rate (amsgrad is enabled).
        base_kn, depth, dropout_rate, ln_eps: MLP width, depth, dropout rate
            and LayerNorm epsilon; the mixers and the rotation layer use
            max(1, depth - 1) as their depth.
        imu_rot_heads: number of rotated streams.
        fs: sampling rate forwarded to the feature extractors.
        use_orig_input: keep the unrotated input as an additional stream.
        use_binary_selection: use binary instead of soft sigmoid gates.
        pad_mode: tf.pad mode used for block framing.
        weight_decay: L2 regularisation factor; None falls back to the
            module-level WEIGHT_DECAY.
        block_specs: list of block-set dicts (name, num_blocks, use_time,
            use_freq); None falls back to the module-level BLOCK_SPECS.

    Returns:
        A compiled Model (sparse categorical cross-entropy, accuracy metric).
    """
    # Resolve the module-level fallbacks at call time.
    wd = WEIGHT_DECAY if weight_decay is None else weight_decay
    specs = BLOCK_SPECS if block_specs is None else block_specs

    inputs = Input(shape=x_shape[1:])     # [T, 6]
    x = inputs

    rot_layer = Multihead3DRotationOfficial(
        head_nums=imu_rot_heads, fs=fs,
        mlp_base=base_kn, mlp_depth=max(1, depth - 1),
        drop=dropout_rate, wd=wd,
        ln_eps=ln_eps, block_specs=specs,
        use_binary=use_binary_selection, pad_mode=pad_mode,
        name='multihead_rot_official'
    )
    rotated_list = rot_layer(x)           # list of [B, T, 6]

    # Every stream (optional original + rotated ones) gets its two L2 channels.
    streams = []
    add_l2 = AddL2ChannelsPublic()
    if use_orig_input:
        streams.append(add_l2(x))         # [B, T, 8]
    for xr in rotated_list:
        streams.append(add_l2(xr))
    concat_streams = Concatenate(axis=-1, name='concat_streams')(streams)  # [B, T, 8*(1+heads)]

    # Axis tags: one [axis_type, sensor_type] row per channel of a stream,
    # repeated for every stream so the rows line up with concat_streams.
    num_streams = (1 if use_orig_input else 0) + imu_rot_heads
    axis_tags_one_stream = np.array(
        [[i + 1, 1 if (i <= 2 or i == 6) else 2] for i in range(8)],
        dtype=np.float32
    )
    axis_tags_all = np.concatenate([axis_tags_one_stream] * num_streams, axis=0)
    tag_spec_main = {'axis_tags': axis_tags_all}
    tag_dim_main = axis_tags_all.shape[1]

    feats_all_sets = []
    for spec in specs:
        ext = BlockTSFExtractor(num_blocks=spec['num_blocks'], fs=fs,
                                use_time=spec['use_time'], use_freq=spec['use_freq'],
                                tag_spec=tag_spec_main, pad_mode=pad_mode,
                                name=f'main_ext_{spec["name"]}')
        feat_dim = _feat_dim_for_spec(spec['use_time'], spec['use_freq'], tag_dim_main)
        mix = TSFMixerBlock(feat_dim=feat_dim, axis_hidden=base_kn, out_hidden=base_kn,
                            base_depth=max(1, depth - 1), drop=dropout_rate,
                            wd=wd, ln_eps=ln_eps,
                            use_binary=use_binary_selection,
                            name=f'main_mix_{spec["name"]}')
        td  = TimeDistributed(mix, name=f'main_td_{spec["name"]}')
        flt = Flatten(name=f'main_flat_{spec["name"]}')

        tsf_blocks = ext(concat_streams)  # [B, K, A_all, F]
        blk_feat   = td(tsf_blocks)       # [B, K, H]
        blk_feat   = flt(blk_feat)        # [B, K*H]
        feats_all_sets.append(blk_feat)

    z = Concatenate(name='main_concat_sets')(feats_all_sets)
    cls_stack = MLPStack(base_kn=base_kn, depth=depth,
                         drop=dropout_rate, wd=wd,
                         ln_eps=ln_eps, name='cls')
    z = cls_stack(z)
    logits = Dense(n_classes, kernel_regularizer=l2(wd),
                   name='logits')(z)
    # The softmax is pinned to float32 regardless of the layer dtype policy.
    probs  = Activation('softmax', dtype='float32', name='softmax')(logits)

    model = Model(inputs, probs, name='rTsfNet_official_aligned')

    opt = Adam(learning_rate=learning_rate, amsgrad=True)
    model.compile(
        loss='sparse_categorical_crossentropy',
        optimizer=opt,
        metrics=['accuracy']
    )
    return model

# ---------------------------
# 13) Build model and compute size
# ---------------------------
# Input shape including a dummy batch dimension; the builder only uses x_shape[1:].
x_shape = (1, WINDOW_SAMPLES, 6)

# Build the compiled model from the hyperparameters defined above.
model = r_tsf_net_official(
    x_shape=x_shape,
    n_classes=NUM_CLASSES,
    learning_rate=LR,
    base_kn=MLP_BASE,
    depth=MLP_DEPTH,
    dropout_rate=DROPOUT,
    imu_rot_heads=IMU_ROT_HEADS,
    fs=FS,
    use_orig_input=USE_ORIG_INPUT,
    use_binary_selection=USE_BINARY_SELECTION,
    ln_eps=LN_EPS,
    pad_mode=PAD_MODE
)

print("\n====== Keras model.summary() ======\n")
model.summary(line_length=140)

# Parameter count as reported by Keras for the whole model.
total_params = model.count_params()
print(f"\nTotal parameters: {total_params:,}")

def fmt_mb(n_bytes: int) -> str:
    """Render a byte count in megabytes (1 MB = 1024 * 1024 bytes), two decimals."""
    megabytes = n_bytes / 1024 / 1024
    return "{:.2f} MB".format(megabytes)

# Storage needed for the parameters alone at 4 and 2 bytes per parameter.
bytes_fp32 = total_params * 4
bytes_fp16 = total_params * 2

print("\n====== Model size estimate (parameters only) ======")
print(f"FP32 (float32, 4B/param): {fmt_mb(bytes_fp32)}")
print(f"FP16 (float16, 2B/param): {fmt_mb(bytes_fp16)}")

# Write the untrained weights to disk and report the resulting file size,
# as an on-disk counterpart to the estimate above.
models_dir = BASE / "models"
models_dir.mkdir(parents=True, exist_ok=True)
tmp_path = models_dir / "rtsfnet_official_dummy.weights.h5"
model.save_weights(tmp_path)
file_bytes = tmp_path.stat().st_size
print(f"\nRandom-initialised weights saved to: {tmp_path.name}")
print(f"Actual .weights.h5 file size:       {fmt_mb(file_bytes)}")

print("\n[Official rTsfNet (IMWUT 2024) – structure & size done]\n")


[Official rTsfNet (IMWUT 2024) – structure & size]
Using defaults: NUM_CLASSES = 8, WINDOW_SAMPLES = 150.

Config for size check:
  NUM_CLASSES         = 8
  WINDOW_SAMPLES      = 150
  IMU_ROT_HEADS       = 2
  MLP_BASE            = 128
  MLP_DEPTH           = 3
  DROPOUT             = 0.5
  WEIGHT_DECAY        = 1e-06
  USE_ORIG_INPUT      = True
  USE_BINARY_SELECTION= True
  PAD_MODE            = SYMMETRIC









Total parameters: 960,698

FP32 (float32, 4B/param): 3.66 MB
FP16 (float16, 2B/param): 1.83 MB

Random-initialised weights saved to: rtsfnet_official_dummy.weights.h5
Actual .weights.h5 file size:       3.98 MB

[Official rTsfNet (IMWUT 2024) – structure & size done]

