In [5]:
# ============ Lightweight rTsfNet (Step 10) – structure & size ============

import json, math
from pathlib import Path

import numpy as np
import tensorflow as tf
from tensorflow.keras import Input, Model
from tensorflow.keras.layers import (
    Dense, Dropout, LayerNormalization, LeakyReLU,
    Layer, Lambda, Flatten, GlobalAveragePooling1D, Activation
)
from tensorflow.keras.optimizers import Adam
from tensorflow.keras.regularizers import l2

print("\n[Lightweight rTsfNet – structure & size]")

# ---------------------------
# 1) Load num_classes and window length if configs exist
# ---------------------------
# Optional config file: <BASE>/configs/classes.json.
#   required key : "num_classes"
#   optional key : "window_config" -> "window_samples"
BASE = Path("/content")
CFG_DIR = BASE / "configs"
CLASSES_JSON = CFG_DIR / "classes.json"

# Fallbacks used when the file (or its window entry) is missing.
# Adjust these defaults to your actual setup if needed.
DEFAULT_NUM_CLASSES = 8
DEFAULT_WINDOW_SAMPLES = 150

if CLASSES_JSON.is_file():
    # Read as UTF-8 explicitly instead of relying on the platform default encoding.
    with open(CLASSES_JSON, "r", encoding="utf-8") as f:
        classes_cfg = json.load(f)
    NUM_CLASSES = int(classes_cfg["num_classes"])
    # `or {}` also covers an explicit `"window_config": null` entry.
    window_cfg = classes_cfg.get("window_config") or {}
    WINDOW_SAMPLES = int(window_cfg.get("window_samples", DEFAULT_WINDOW_SAMPLES))
    print(f"Detected NUM_CLASSES = {NUM_CLASSES}, WINDOW_SAMPLES = {WINDOW_SAMPLES} from configs.")
else:
    NUM_CLASSES = DEFAULT_NUM_CLASSES
    WINDOW_SAMPLES = DEFAULT_WINDOW_SAMPLES
    # Report the path that was actually probed, so the message stays correct if BASE changes.
    print(f"Warning: {CLASSES_JSON} not found.")
    print(f"Using defaults: NUM_CLASSES = {NUM_CLASSES}, WINDOW_SAMPLES = {WINDOW_SAMPLES}.")

# ---------------------------
# 2) Hyperparameters (must match your training script)
# ---------------------------
# Values below are forwarded to r_tsf_net() further down in this file.
FS = 50.0                # `fs` of the feature layer (upper end of its frequency axis is FS / 2)
IMU_ROT_HEADS = 2        # number of rotation heads
MLP_BASE = 128           # width unit of the classifier MLP (layer k has MLP_BASE * 2**k units)
MLP_DEPTH = 3            # number of hidden Dense blocks
DROPOUT = 0.5            # dropout rate after every hidden block
LR = 1e-3                # Adam learning rate
WEIGHT_DECAY = 1e-6      # L2 kernel-regularisation factor
USE_ORIG_INPUT = True    # also feed the un-rotated stream to the feature layer

# Echo the settings that determine the model size, one aligned "NAME = value" row each.
print("\nConfig for size check:")
print("\n".join(
    f"  {label:<14}= {value}"
    for label, value in (
        ("NUM_CLASSES", NUM_CLASSES),
        ("WINDOW_SAMPLES", WINDOW_SAMPLES),
        ("IMU_ROT_HEADS", IMU_ROT_HEADS),
        ("MLP_BASE", MLP_BASE),
        ("MLP_DEPTH", MLP_DEPTH),
        ("DROPOUT", DROPOUT),
        ("WEIGHT_DECAY", WEIGHT_DECAY),
    )
))

# ---------------------------
# 3) Layers (identical to your training code)
# ---------------------------
class TSFFeatureLayer(Layer):
    """
    Weight-free layer that summarises every channel of a window with 19 statistics.

    Input: [B, T, C]  Output: [B, C, 19]
    Time domain (12): mean/std/max/min/ptp/rms/energy/skew/kurt/zcr/ar1/ar2
    Frequency domain (7): centroid/entropy/flatness/soft-peak frequency
                          + bandpower (0.5–3 / 3–8 / 8–15 Hz)

    The features appear along the last output axis in exactly the order listed above.
    `fs` is the sampling rate used to label the FFT bins; `eps` (1e-8) guards
    divisions and logarithms.
    """

    def __init__(self, fs=50.0, **kwargs):
        super().__init__(**kwargs)
        self.fs = float(fs)
        self.eps = 1e-8

    def get_config(self):
        """Return the base layer config extended with `fs` so the layer can be re-created."""
        config = super().get_config()
        config['fs'] = self.fs
        return config

    def _time_features(self, x, mu):
        """Return the 12 time-domain features of x ([B, T, C]); each is [B, 1, C]."""
        kw = {'axis': 1, 'keepdims': True}          # always reduce over the time axis

        sigma = tf.math.reduce_std(x, **kw) + self.eps
        peak_hi = tf.reduce_max(x, **kw)
        peak_lo = tf.reduce_min(x, **kw)
        squared = tf.square(x)
        rms = tf.sqrt(tf.reduce_mean(squared, **kw))
        energy = tf.reduce_sum(squared, **kw)

        # Standardised moments (sigma already carries +eps, so no division by zero).
        zscore = (x - mu) / sigma
        skew = tf.reduce_mean(tf.pow(zscore, 3), **kw)
        kurt = tf.reduce_mean(tf.pow(zscore, 4), **kw)

        # Zero-crossing rate: |sign[t+1] - sign[t]| is 2 for a full sign flip, hence the /2.
        signs = tf.sign(x)
        zcr = tf.reduce_mean(tf.abs(signs[:, 1:, :] - signs[:, :-1, :]), **kw) / 2.0

        def lag_ratio(lag):
            # sum_t x[t] * x[t+lag]  /  (sum_t x[t]^2 + eps)
            lead, follow = x[:, :-lag, :], x[:, lag:, :]
            return tf.reduce_sum(lead * follow, **kw) / (
                tf.reduce_sum(tf.square(lead), **kw) + self.eps
            )

        return [mu, sigma, peak_hi, peak_lo, peak_hi - peak_lo, rms, energy,
                skew, kurt, zcr, lag_ratio(1), lag_ratio(2)]

    def _spectral_features(self, x, mu):
        """Return the 7 frequency-domain features of x ([B, T, C]); each is [B, 1, C]."""
        kw = {'axis': 1, 'keepdims': True}          # here axis 1 is the frequency axis

        # Power spectrum of the mean-removed signal; rfft works on the last axis.
        centered = tf.transpose(x - mu, [0, 2, 1])                       # [B, C, T]
        spectrum = tf.signal.rfft(centered)                              # [B, C, F]
        power = tf.transpose(tf.square(tf.abs(spectrum)) + self.eps,
                             [0, 2, 1])                                  # [B, F, C]

        # Bin frequencies, evenly spaced from 0 to fs/2.
        # NOTE(review): rfft bins sit at k*fs/T, so this spacing is exact only for even T.
        n_bins = tf.shape(power)[1]
        half_fs = tf.cast(self.fs, tf.float32) / 2.0
        freqs = tf.reshape(tf.linspace(0.0, half_fs, n_bins), [1, n_bins, 1])

        total = tf.reduce_sum(power, **kw) + self.eps
        p = power / total                                                # per-bin power share

        centroid = tf.reduce_sum(p * freqs, **kw)
        # Entropy of the power distribution, scaled by log(number of bins).
        entropy = -tf.reduce_sum(p * tf.math.log(p + self.eps), **kw) / \
            tf.math.log(tf.cast(n_bins, tf.float32) + self.eps)

        # Spectral flatness: geometric mean over arithmetic mean of the power.
        geo_mean = tf.exp(tf.reduce_mean(tf.math.log(power), **kw))
        flatness = geo_mean / (tf.reduce_mean(power, **kw) + self.eps)

        # Softmax-weighted mean frequency (temperature 10) as a smooth peak estimate.
        # NOTE(review): `power` is unnormalised, so for large magnitudes the softmax
        # is close to one-hot — confirm that is intended.
        weights = tf.nn.softmax(power * 10.0, axis=1)                    # [B, F, C]
        soft_peak = tf.reduce_sum(weights * freqs, **kw)

        def band_share(low, high):
            # Fraction of total power in the half-open band [low, high).
            in_band = tf.cast((freqs >= low) & (freqs < high), tf.float32)
            return tf.reduce_sum(power * in_band, **kw) / total

        bands = [band_share(lo, hi) for lo, hi in ((0.5, 3.0), (3.0, 8.0), (8.0, 15.0))]
        return [centroid, entropy, flatness, soft_peak] + bands

    def call(self, x):  # x: [B, T, C]
        mu = tf.reduce_mean(x, axis=1, keepdims=True)
        columns = self._time_features(x, mu) + self._spectral_features(x, mu)
        stacked = tf.concat(columns, axis=1)                             # [B, Fnum, C]
        return tf.transpose(stacked, [0, 2, 1])                          # [B, C, Fnum]


class Multihead3DRotation(Layer):
    """
    Input [B, T, 6] (ACC + GYR), output: a list of length head_nums, each element is [B, T, 6].

    The window is averaged over time, passed through `param_depth` ReLU Dense layers of
    width `base_kn`, and every head then predicts 4 tanh-bounded numbers: a rotation
    axis (3) and an angle (1, scaled by pi). The resulting rotation matrix is applied
    to channels 0:3 and 3:6 of every time step.
    """

    def __init__(self, head_nums=2, base_kn=64, param_depth=2, **kwargs):
        super().__init__(**kwargs)
        self.head_nums = head_nums
        self.base_kn = base_kn
        self.param_depth = param_depth
        self.eps = 1e-8

        # Attribute names and creation order are kept as-is: they may be part of how
        # sub-layer weights are named/located in saved weight files.
        self.gap = GlobalAveragePooling1D()
        self.mlp = [Dense(base_kn, activation='relu') for _ in range(param_depth)]
        self.out_heads = [Dense(4, activation='tanh') for _ in range(head_nums)]

    def get_config(self):
        """Return the base layer config plus the three constructor arguments."""
        return {
            **super().get_config(),
            'head_nums': self.head_nums,
            'base_kn': self.base_kn,
            'param_depth': self.param_depth,
        }

    def compute_output_shape(self, input_shape):
        """One output per head, each with the same shape as the input."""
        return [tf.TensorShape(input_shape) for _head in range(self.head_nums)]

    def _axis_angle_to_R(self, axis_raw, angle_raw):
        """Build rotation matrices [B,3,3] from raw axes [B,3] and raw angles [B,1]."""
        unit = axis_raw / (tf.norm(axis_raw, axis=-1, keepdims=True) + self.eps)
        theta = angle_raw * math.pi                                       # [B,1]

        # Cross-product (skew-symmetric) matrix of the unit axis, assembled row by row.
        ux, uy, uz = unit[:, 0], unit[:, 1], unit[:, 2]
        zero = tf.zeros_like(ux)
        cross = tf.stack([
            tf.stack([zero, -uz, uy], axis=-1),
            tf.stack([uz, zero, -ux], axis=-1),
            tf.stack([-uy, ux, zero], axis=-1),
        ], axis=1)                                                        # [B,3,3]

        outer = unit[:, :, None] * unit[:, None, :]                       # u u^T, [B,3,3]
        eye = tf.eye(3, dtype=unit.dtype)                                 # broadcast over B

        cos_t = tf.reshape(tf.cos(theta), [-1, 1, 1])
        sin_t = tf.reshape(tf.sin(theta), [-1, 1, 1])

        # R = cos*I + (1 - cos) * u u^T + sin * [u]_x
        return cos_t * eye + (1.0 - cos_t) * outer + sin_t * cross

    @staticmethod
    def _rotate(R, vec):
        """Apply R ([B,3,3]) to every time step of vec ([B,T,3]); returns [B,T,3]."""
        rotated_t = tf.matmul(R, tf.transpose(vec, [0, 2, 1]))            # [B,3,T]
        return tf.transpose(rotated_t, [0, 2, 1])

    def call(self, x):   # x: [B, T, 6]
        acc, gyr = x[:, :, :3], x[:, :, 3:6]

        # Shared trunk: time-averaged window -> MLP.
        hidden = self.gap(x)                                              # [B, 6]
        for dense in self.mlp:
            hidden = dense(hidden)

        rotated = []
        for head in self.out_heads:
            params = head(hidden)                                         # [B, 4]
            R = self._axis_angle_to_R(params[:, :3],
                                      tf.expand_dims(params[:, 3], -1))   # [B,3,3]
            # The same rotation is used for both 3-axis groups.
            rotated.append(
                tf.concat([self._rotate(R, acc), self._rotate(R, gyr)], axis=-1)
            )                                                             # [B,T,6]
        return rotated


def add_l2_channels(x):     # x: [B, T, 6]
    """
    Append the Euclidean norm of channels 0:3 and of channels 3:6 as two extra channels.

    Returns a tensor of shape [B, T, 8]: the 6 input channels followed by the two norms.
    """
    def magnitude(vec):
        # Per-time-step L2 norm over the last axis, kept as a trailing channel.
        # NOTE(review): the sqrt gradient is not finite where the sum of squares is
        # exactly 0 — confirm inputs never contain all-zero samples when training.
        return tf.sqrt(tf.reduce_sum(tf.square(vec), axis=-1, keepdims=True))

    return tf.concat(
        [x, magnitude(x[:, :, :3]), magnitude(x[:, :, 3:6])],
        axis=-1,
    )  # [B, T, 8]

# ---------------------------
# 4) rTsfNet main body (identical structure)
# ---------------------------
def r_tsf_net(x_shape,
              n_classes,
              learning_rate=1e-3,
              base_kn=128,
              depth=3,
              dropout_rate=0.5,
              imu_rot_heads=2,
              fs=50.0,
              use_orig_input=True,
              weight_decay=None):
    """
    Build and compile the rTsfNet classifier.

    Pipeline: input [T, 6] -> multi-head 3D rotation -> (optional original stream +
    rotated streams), each extended with two L2-norm channels -> channel-wise concat
    -> TSFFeatureLayer -> Flatten -> `depth` blocks of
    Dense / LayerNormalization / LeakyReLU / Dropout -> Dense logits -> softmax.

    Args:
        x_shape: Shape of the training array, (N, T, 6); only x_shape[1:] is used.
        n_classes: Number of output classes.
        learning_rate: Learning rate of the Adam (amsgrad) optimizer.
        base_kn: Width unit of the classifier MLP; block k has base_kn * 2**k units,
            with k running from depth-1 down to 0 (widest block first).
        depth: Number of hidden Dense blocks.
        dropout_rate: Dropout rate applied after every hidden block.
        imu_rot_heads: Number of rotation heads.
        fs: Sampling rate forwarded to TSFFeatureLayer.
        use_orig_input: If True, the un-rotated input is added as an extra stream.
        weight_decay: L2 kernel-regularisation factor for the Dense layers. None
            (default) falls back to the module-level WEIGHT_DECAY, which is what
            this function always used before the parameter existed.

    Returns:
        A compiled keras Model (sparse categorical cross-entropy loss, accuracy metric).
    """
    if weight_decay is None:
        weight_decay = WEIGHT_DECAY

    inputs = Input(shape=x_shape[1:])     # [T, 6]
    x = inputs

    rot_layer = Multihead3DRotation(
        head_nums=imu_rot_heads,
        base_kn=64,
        param_depth=2,
        name='multihead_rot'
    )
    rotated_list = rot_layer(x)   # list of [B, T, 6]

    streams = []
    if use_orig_input:
        streams.append(Lambda(add_l2_channels, name='orig_plus_l2')(x))
    for i, xr in enumerate(rotated_list):
        streams.append(Lambda(add_l2_channels, name=f'rot{i}_plus_l2')(xr))

    concat_streams = Lambda(
        lambda lst: tf.concat(lst, axis=-1),
        name='concat_streams'
    )(streams)  # [B, T, 8 * (use_orig_input + heads)]

    tsf = TSFFeatureLayer(fs=fs, name='tsf')(concat_streams)  # [B, C_total, F]

    z = Flatten(name='flatten')(tsf)
    for k in reversed(range(depth)):
        z = Dense(
            # Fix: honour the `base_kn` argument instead of the global MLP_BASE.
            base_kn * (2**k),
            kernel_regularizer=l2(weight_decay),
            name=f'fc_{k}'
        )(z)
        z = LayerNormalization(epsilon=1e-7, name=f'ln_{k}')(z)
        z = LeakyReLU(name=f'lrelu_{k}')(z)
        z = Dropout(dropout_rate, name=f'drop_{k}')(z)

    logits = Dense(
        n_classes,
        kernel_regularizer=l2(weight_decay),
        name='logits'
    )(z)
    # Softmax is pinned to float32 regardless of the global dtype policy.
    probs  = Activation('softmax', dtype='float32', name='softmax')(logits)

    model = Model(inputs, probs, name='rTsfNet_officially_aligned_fixed')

    opt = Adam(learning_rate=learning_rate, amsgrad=True)
    model.compile(
        loss='sparse_categorical_crossentropy',
        optimizer=opt,
        metrics=['accuracy']
    )
    return model

# ---------------------------
# 5) Build model and compute size
# ---------------------------
# x_shape in the training script is X_train.shape: (N, T, 6).
# Here we only need a dummy shape with the same last two dims.
# Batch dimension is a placeholder; r_tsf_net only reads x_shape[1:].
x_shape = (1, WINDOW_SAMPLES, 6)

# Build with the hyperparameters defined above in this file.
model = r_tsf_net(
    x_shape,
    NUM_CLASSES,
    learning_rate=LR, base_kn=MLP_BASE, depth=MLP_DEPTH,
    dropout_rate=DROPOUT, imu_rot_heads=IMU_ROT_HEADS,
    fs=FS, use_orig_input=USE_ORIG_INPUT,
)

# Layer-by-layer overview (wide lines so the long layer names fit).
print("\n====== Keras model.summary() ======\n")
model.summary(line_length=140)

# Total number of weights, printed with thousands separators.
total_params = model.count_params()
print("\nTotal parameters: {:,}".format(total_params))

# Parameter size estimate
def fmt_mb(n_bytes: int) -> str:
    """Format a byte count as megabytes (1 MB = 1024 * 1024 bytes) with two decimals."""
    megabytes = n_bytes / 1024 / 1024
    return "{:.2f} MB".format(megabytes)

# In-memory size if every parameter is stored with 4 bytes (float32) or 2 bytes (float16).
bytes_fp32, bytes_fp16 = total_params * 4, total_params * 2

print("\n====== Model size estimate (parameters only) ======")
print("\n".join(
    f"{label}: {fmt_mb(n_bytes)}"
    for label, n_bytes in (
        ("FP32 (float32, 4B/param)", bytes_fp32),
        ("FP16 (float16, 2B/param)", bytes_fp16),
    )
))

# Write the (untrained) weights to disk to measure the real .weights.h5 size.
# The file is left in place under <BASE>/models after the measurement.
models_dir = BASE / "models"
tmp_path = models_dir / "rtsfnet_light_dummy.weights.h5"
models_dir.mkdir(parents=True, exist_ok=True)
model.save_weights(tmp_path)
file_bytes = tmp_path.stat().st_size

print(f"\nRandom-initialised weights saved to {tmp_path.name}")
print(f"Actual .weights.h5 file size: {fmt_mb(file_bytes)}")

print("\n[Lightweight rTsfNet – structure & size done]\n")


[Lightweight rTsfNet – structure & size]
Using defaults: NUM_CLASSES = 8, WINDOW_SAMPLES = 150.

Config for size check:
  NUM_CLASSES   = 8
  WINDOW_SAMPLES= 150
  IMU_ROT_HEADS = 2
  MLP_BASE      = 128
  MLP_DEPTH     = 3
  DROPOUT       = 0.5
  WEIGHT_DECAY  = 1e-06





Total parameters: 406,160

FP32 (float32, 4B/param): 1.55 MB
FP16 (float16, 2B/param): 0.77 MB

Random-initialised weights saved to rtsfnet_light_dummy.weights.h5
Actual .weights.h5 file size: 1.62 MB

[Lightweight rTsfNet – structure & size done]

