---
Cell 1: Library Imports
---
---

In [1]:
# Cell 1: Library Imports
import os
import argparse
import numpy as np
import pandas as pd
import scipy.io as scio
import time
import torch
torch.cuda.empty_cache()
import torch._dynamo
from torch.utils.data import TensorDataset, DataLoader
# from ptflops import get_model_complexity_info
from sklearn.metrics import classification_report, accuracy_score
from copy import deepcopy
import json
from sklearn.model_selection import train_test_split
from torch.utils.data import DataLoader, TensorDataset
import torch

---
Cell 2: preset.py
---
---

In [2]:
"""
[file]          preset.py
[description]   default settings of WiFi-based models
"""
def _one_hot_table(var_names, var_width):
    """Return {"nan": all-zero vector, name_i: vector with a single 1 at position i}."""
    var_table = {"nan": [0] * var_width}
    for var_index, var_name in enumerate(var_names):
        var_table[var_name] = [1 if var_pos == var_index else 0 for var_pos in range(var_width)]
    return var_table


# NOTE(review): not referenced in the visible part of this file.
minidata_set = 1

# Input/output locations.
# NOTE(review): the "save" file name mentions lstm / epoch=80 / batchsize=32 /
# empty_room, which does not match the model, epoch, batch size and
# environment configured below -- confirm the intended name.
_preset_path = {
    "data_x": "/kaggle/input/wimans/wifi_csi/amp",    # directory of CSI amplitude files
    "data_y": "/kaggle/input/wimans/annotation.csv",  # path of annotation file
    "save": "result_lstm_epoch=80_batchsize=32_envs=empty_room_wifiband=2.4.json",  # path to save results
}

# Which annotation rows are selected for the experiment.
_preset_data = {
    "num_users": ["0", "1", "2", "3", "4", "5"],  # number(s) of users
    "wifi_band": ["2.4"],                         # WiFi band(s)
    "environment": ["classroom"],                 # ["classroom"], ["meeting_room"], ["empty_room"]
    "length": 3000,                               # default length of CSI
}

# Hyperparameters of the models.
_preset_nn = {
    "lr": 1e-3,         # learning rate
    "epoch": 100,       # number of epochs
    "batch_size": 64,   # batch size
    "threshold": 0.5,   # threshold to binarize sigmoid outputs
}

# Label encodings: "nan" is the all-zero vector, every other label is one-hot
# in the order listed here.
_preset_encoding = {
    "activity": _one_hot_table(
        ["nothing", "walk", "rotation", "jump", "wave",
         "lie_down", "pick_up", "sit_down", "stand_up"],
        9,
    ),
    "location": _one_hot_table(["a", "b", "c", "d", "e"], 5),
}

preset = {
    # "ST-RF", "MLP", "LSTM", "CNN-1D", "CNN-2D", "CLSTM", "ABLSTM", "THAT", "bi-LSTM", "ResNet18"
    "model": "THAT",
    # "identity", "activity", "location", "count"
    "task": "activity",
    # number of repeated experiments
    "repeat": 1,
    "path": _preset_path,
    "data": _preset_data,
    "nn": _preset_nn,
    "encoding": _preset_encoding,
}


# Few-shot parameters (manually set)
dest_env = "empty_room"    # destination environment: "classroom", "meeting_room" or "empty_room"
few_shot_epochs = 100      # number of epochs for few-shot training
few_shot_num_samples = 5   # number of samples to use from the destination test data

Confusion_matrix = 1

# Human-readable tag describing this run's configuration.
name_run = (
    f"few={dest_env},{few_shot_epochs},{few_shot_num_samples},"
    f"m={preset['model']},t={preset['task']},"
    f"epoch={preset['nn']['epoch']},batch={preset['nn']['batch_size']},"
    f"environment={preset['data']['environment']}"
)
print(name_run)

few=empty_room,100,5,m=THAT,t=activity,epoch=100,batch=64,environment=['classroom']


---
Cell 3: load_data.py
---
---

---
rpca
---
---

In [3]:
"""
[file]          load_data.py
[description]   load annotation file and CSI amplitude, and encode labels
"""
from sklearn.preprocessing import OneHotEncoder
import numpy as np

# =========================================================
# Select the model input representation:
# "raw"     : the raw amplitude as loaded
# "lowrank" : low-rank component from RPCA (L)
# "sparse"  : sparse component from RPCA (S)
CSI_INPUT_MODE = "lowrank"   # <-- raw / lowrank / sparse

# RPCA (IALM) settings, chosen for speed
RPCA_MAX_ITER = 80
RPCA_TOL      = 1e-5
RPCA_RHO      = 1.5
RPCA_MU_INIT  = None
RPCA_LAMBDA   = None     # None -> 1/sqrt(max(m,n))

# Cache (very important for speed)
CACHE_ENABLED = True
CACHE_ROOT    = "/kaggle/working/csi_cache"  # RPCA outputs are stored here
# =========================================================


def _soft_threshold(X, tau):
    """Shrink every entry of X toward zero by tau; entries with |x| <= tau become 0."""
    var_magnitude = np.abs(X) - tau
    var_shrunk = np.maximum(var_magnitude, 0.0)
    return np.sign(X) * var_shrunk

def _svt(X, tau):
    """Singular value thresholding: subtract tau from each singular value (floored at 0) and rebuild X."""
    var_left, var_sigma, var_right_t = np.linalg.svd(X, full_matrices=False)
    var_sigma = np.maximum(var_sigma - tau, 0.0)
    # Every singular value was thresholded away: the result is the zero matrix.
    if not np.any(var_sigma != 0):
        return np.zeros_like(X)
    return np.matmul(var_left * var_sigma, var_right_t)

def _rpca_ialm(M, lam=None, mu=None, rho=1.5, max_iter=80, tol=1e-5):
    """
    Split a 2-D matrix M into a low-rank part and a sparse part (M ~ L + S).

    Iterates singular-value thresholding (for L) and element-wise soft
    thresholding (for S) with a multiplier matrix Y, multiplying the penalty
    mu by rho after every non-converged iteration.

    Args:
        M: 2-D array; computations are done in float64.
        lam: sparsity weight; None -> 1 / sqrt(max(rows, cols)).
        mu: initial penalty; None -> 1.25 / (largest singular value of M).
        rho: growth factor applied to mu per iteration.
        max_iter: maximum number of iterations.
        tol: stop when ||M - L - S||_F / ||M||_F drops below this value.

    Returns:
        (L, S) as float32 arrays with the same shape as M.
    """
    M = M.astype(np.float64, copy=False)
    var_rows, var_cols = M.shape

    if lam is None:
        lam = 1.0 / np.sqrt(max(var_rows, var_cols))

    if mu is None:
        if M.size:
            var_top_sv = np.linalg.svd(M, compute_uv=False, full_matrices=False)[0]
        else:
            var_top_sv = 1.0
        mu = 1.25 / (var_top_sv + 1e-12)

    data_low_rank = np.zeros_like(M)
    data_sparse = np.zeros_like(M)
    data_dual = np.zeros_like(M)

    # Small epsilon keeps the relative residual finite for an all-zero M.
    var_norm_m = np.linalg.norm(M, ord="fro") + 1e-12

    for _ in range(max_iter):
        var_inv_mu = 1.0 / mu
        data_low_rank = _svt(M - data_sparse + var_inv_mu * data_dual, var_inv_mu)
        data_sparse = _soft_threshold(M - data_low_rank + var_inv_mu * data_dual, lam / mu)

        data_residual = M - data_low_rank - data_sparse
        data_dual = data_dual + mu * data_residual

        var_rel_residual = np.linalg.norm(data_residual, ord="fro") / var_norm_m
        if var_rel_residual < tol:
            break
        mu *= rho

    return data_low_rank.astype(np.float32), data_sparse.astype(np.float32)


def _rpca_keep_shape(X):
    """
    Run RPCA on X viewed as a (T, F) matrix and return (L, S) in X's original shape.

    The first axis is kept as rows and all remaining axes are flattened into
    columns; a 1-D input is treated as a single column. RPCA settings are read
    from the module-level RPCA_* constants at call time.
    """
    X = np.asarray(X, dtype=np.float32)
    var_rpca_args = (RPCA_LAMBDA, RPCA_MU_INIT, RPCA_RHO, RPCA_MAX_ITER, RPCA_TOL)

    if X.ndim == 1:
        data_low_rank, data_sparse = _rpca_ialm(X[:, None], *var_rpca_args)
        return data_low_rank[:, 0], data_sparse[:, 0]

    var_rows = X.shape[0]
    var_cols = int(np.prod(X.shape[1:]))
    data_low_rank, data_sparse = _rpca_ialm(X.reshape(var_rows, var_cols), *var_rpca_args)
    return data_low_rank.reshape(X.shape), data_sparse.reshape(X.shape)


def _cache_path(label, mode):
    """Return the cache file location <CACHE_ROOT>/<mode>/<label>.npy."""
    # mode is the name of the cached component, e.g. "lowrank" or "sparse"
    var_file_name = f"{label}.npy"
    var_mode_dir = os.path.join(CACHE_ROOT, mode)
    return os.path.join(var_mode_dir, var_file_name)


def _apply_mode_with_cache(data_csi, label):
    """
    Convert one CSI sample according to CSI_INPUT_MODE, using the on-disk cache.

    Args:
        data_csi: numpy array holding one CSI sample.
        label: sample label; together with the mode it names the cache file.

    Returns:
        float32 array: the input itself for "raw", the RPCA low-rank component
        for "lowrank", the RPCA sparse component for "sparse".

    Raises:
        ValueError: if CSI_INPUT_MODE is not "raw", "lowrank" or "sparse".
            Previously any unrecognised value (e.g. a typo) was silently
            treated as "sparse".
    """
    mode = CSI_INPUT_MODE
    if mode == "raw":
        return data_csi.astype(np.float32, copy=False)
    if mode not in ("lowrank", "sparse"):
        raise ValueError(
            f"Unknown CSI_INPUT_MODE: {mode!r}. Use 'raw', 'lowrank', or 'sparse'."
        )

    # NOTE(review): the cache key is only (mode, label). Changing the RPCA_*
    # settings or the input data does not invalidate existing cache files.
    var_cache_file = None
    if CACHE_ENABLED:
        var_cache_file = _cache_path(label, mode)
        os.makedirs(os.path.dirname(var_cache_file), exist_ok=True)
        if os.path.exists(var_cache_file):
            return np.load(var_cache_file).astype(np.float32, copy=False)

    data_low_rank, data_sparse = _rpca_keep_shape(data_csi)
    data_selected = data_low_rank if mode == "lowrank" else data_sparse
    data_out = data_selected.astype(np.float32, copy=False)

    if var_cache_file is not None:
        np.save(var_cache_file, data_out)

    return data_out


def load_data_y(var_path_data_y,
                var_environment=None,
                var_wifi_band=None,
                var_num_users=None):
    """
    Load the annotation file (*.csv) as an all-string dataframe and filter its rows.

    Each filter is a collection of accepted values for one column
    ("environment", "wifi_band", "number_of_users"); a filter left as None is
    not applied.
    """
    data_pd_y = pd.read_csv(var_path_data_y, dtype=str)
    var_filters = (
        ("environment", var_environment),
        ("wifi_band", var_wifi_band),
        ("number_of_users", var_num_users),
    )
    for var_column, var_allowed in var_filters:
        if var_allowed is None:
            continue
        data_pd_y = data_pd_y[data_pd_y[var_column].isin(var_allowed)]
    return data_pd_y


def load_data_x(var_path_data_x, var_label_list):
    """
    Load CSI amplitude (*.npy) files for a list of labels and stack them.

    Every sample is passed through _apply_mode_with_cache (raw / low-rank /
    sparse, as configured) and then zero-padded at the FRONT of its first axis
    up to preset["data"]["length"]. All other axes are left untouched, so
    samples of any rank are accepted as long as they agree with each other.

    Args:
        var_path_data_x: directory containing "<label>.npy" files.
        var_label_list: labels of the samples to load, in output order.

    Returns:
        numpy array with one entry per label.

    Raises:
        ValueError: if a sample is longer than preset["data"]["length"].
    """
    var_length = preset["data"]["length"]
    data_x = []
    for var_label in var_label_list:
        data_csi = np.load(os.path.join(var_path_data_x, var_label + ".npy"))

        # RPCA low-rank / sparse conversion keeps the shape and uses the cache.
        data_csi = _apply_mode_with_cache(data_csi, var_label)

        var_pad_length = var_length - data_csi.shape[0]
        if var_pad_length < 0:
            # np.pad would reject the negative width with an opaque message.
            raise ValueError(
                f"CSI sample {var_label!r} has {data_csi.shape[0]} time steps, "
                f"which exceeds preset['data']['length'] = {var_length}."
            )
        # Pad only the first axis, whatever the rank of the sample.
        var_pad_width = [(var_pad_length, 0)] + [(0, 0)] * (data_csi.ndim - 1)
        data_x.append(np.pad(data_csi, var_pad_width))

    return np.array(data_x)


def encode_data_y(data_pd_y, var_task):
    """
    Encode labels according to the selected task.

    Args:
        data_pd_y: annotation dataframe.
        var_task: one of "identity", "activity", "location", "count".

    Returns:
        The array produced by the matching encode_* function.

    Raises:
        ValueError: if var_task is not a supported task. Previously this
            surfaced as an UnboundLocalError on the return statement.
    """
    if var_task == "identity":
        return encode_identity(data_pd_y)
    if var_task == "activity":
        return encode_activity(data_pd_y, preset["encoding"]["activity"])
    if var_task == "location":
        return encode_location(data_pd_y, preset["encoding"]["location"])
    if var_task == "count":
        # encode_count receives the location encoding, as in the original call.
        return encode_count(data_pd_y, preset["encoding"]["location"])
    raise ValueError(
        f"Unknown task: {var_task!r}. Use 'identity', 'activity', 'location', or 'count'."
    )


def encode_identity(data_pd_y):
    """
    Encode which of the six user slots are occupied.

    A slot counts as occupied (1) when its "user_<i>_location" entry,
    converted to a string, is anything other than "nan"; otherwise it is 0.
    Returns an int8 array with one row per sample and one column per user.
    """
    var_columns = [f"user_{var_i}_location" for var_i in range(1, 7)]
    data_location_y = data_pd_y[var_columns].to_numpy(copy=True).astype(str)
    data_present = data_location_y != "nan"
    return data_present.astype("int8")


def encode_activity(data_pd_y, var_encoding):
    """
    Encode the six "user_<i>_activity" columns with the given lookup table.

    Every cell is converted to a string and replaced by var_encoding[cell];
    a label missing from var_encoding raises KeyError.
    """
    var_columns = [f"user_{var_i}_activity" for var_i in range(1, 7)]
    data_activity_y = data_pd_y[var_columns].to_numpy(copy=True).astype(str)
    data_activity_onehot_y = []
    for var_sample in data_activity_y:
        data_activity_onehot_y.append([var_encoding[var_label] for var_label in var_sample])
    return np.array(data_activity_onehot_y)


def encode_location(data_pd_y, var_encoding):
    """
    Encode the six "user_<i>_location" columns with the given lookup table.

    Every cell is converted to a string and replaced by var_encoding[cell];
    a label missing from var_encoding raises KeyError.
    """
    var_columns = [f"user_{var_i}_location" for var_i in range(1, 7)]
    data_location_y = data_pd_y[var_columns].to_numpy(copy=True).astype(str)
    data_location_onehot_y = []
    for var_sample in data_location_y:
        data_location_onehot_y.append([var_encoding[var_label] for var_label in var_sample])
    return np.array(data_location_onehot_y)


def encode_count(data_pd_y, var_encoding):
    """
    One-hot encode the number of users present in each sample.

    A user slot counts as present when its "user_<i>_location" entry,
    converted to a string, is not "nan". The per-sample counts are one-hot
    encoded over the sorted set of counts that actually occur in data_pd_y,
    which is what OneHotEncoder.fit_transform produced here. The encoding is
    done with numpy because the `sparse=False` keyword used previously was
    removed from scikit-learn's OneHotEncoder in version 1.4.

    NOTE(review): the output width depends on which counts occur in the given
    dataframe, so two differently filtered dataframes can yield different
    widths -- confirm this is intended.

    Args:
        data_pd_y: annotation dataframe.
        var_encoding: unused; kept so the call signature stays unchanged.

    Returns:
        int8 array of shape (number of samples, number of distinct counts).
    """
    var_columns = [f"user_{var_i}_location" for var_i in range(1, 7)]
    data_location_y = data_pd_y[var_columns].to_numpy(copy=True).astype(str)
    data_count = (data_location_y != "nan").sum(axis=1)

    # Column j of the result corresponds to the j-th smallest distinct count.
    var_categories, var_column_index = np.unique(data_count, return_inverse=True)
    return np.eye(len(var_categories), dtype="int8")[var_column_index]


---
svd
---
---

In [4]:
# """
# [file]          load_data.py
# [description]   load annotation file and CSI amplitude, and encode labels
# """
# from sklearn.preprocessing import OneHotEncoder
# import numpy as np

# # Note: All necessary libraries (os, numpy, pandas, etc.) are imported in Cell 1.
# # from preset import preset   --> preset is already defined in Cell 2.

# # =========================================================
# # NEW: Choose CSI representation mode here (ONLY EDIT THIS)
# # ---------------------------------------------------------
# # "raw"     : use original CSI amplitude as-is
# # "lowrank" : use low-rank approximation (SVD)
# # "sparse"  : keep only large-magnitude entries (dense array with many zeros)
# CSI_INPUT_MODE = "sparse"     # <-- set to: "raw" / "lowrank" / "sparse"

# # Low-rank settings
# LOW_RANK_ENERGY = 0.95     # keep enough singular values to preserve this energy
# LOW_RANK_RANK   = None     # if set to an int (e.g., 10), it overrides ENERGY

# # Sparse settings
# SPARSE_KEEP_RATIO = 0.10   # keep top 10% magnitudes (globally per sample)
# SPARSE_MIN_ABS    = None   # if set (e.g., 0.5), keeps |x|>=threshold instead of keep_ratio
# # =========================================================


# def _low_rank_approx_keep_shape(X, rank=None, energy=0.95):
#     """
#     Low-rank approximation using SVD while preserving the original shape.
#     Works for 1D/2D/ND by flattening all non-time dims into features.
#     Assumes first axis is time.
#     """
#     X = np.asarray(X, dtype=np.float32)

#     if X.ndim == 1:
#         M = X[:, None]  # (T,1)
#         U, S, Vt = np.linalg.svd(M, full_matrices=False)
#         if rank is None:
#             s2 = S**2
#             cum = np.cumsum(s2) / (np.sum(s2) + 1e-12)
#             rank = int(np.searchsorted(cum, energy) + 1)
#         rank = max(1, min(rank, S.shape[0]))
#         M_lr = (U[:, :rank] * S[:rank]) @ Vt[:rank, :]
#         return M_lr[:, 0].astype(np.float32)

#     # ND: reshape to (T, F)
#     T = X.shape[0]
#     F = int(np.prod(X.shape[1:]))
#     M = X.reshape(T, F)

#     U, S, Vt = np.linalg.svd(M, full_matrices=False)

#     if rank is None:
#         s2 = S**2
#         cum = np.cumsum(s2) / (np.sum(s2) + 1e-12)
#         rank = int(np.searchsorted(cum, energy) + 1)

#     rank = max(1, min(rank, S.shape[0]))
#     M_lr = (U[:, :rank] * S[:rank]) @ Vt[:rank, :]

#     return M_lr.reshape(X.shape).astype(np.float32)


# def _to_sparse_dense_keep_shape(X, keep_ratio=0.10, min_abs=None):
#     """
#     Makes X sparse-in-content (many zeros) but keeps it as a dense numpy array
#     so the rest of the pipeline (np.save/np.load/pad/model) doesn't change.
#     Keeps the same shape.
#     """
#     X = np.asarray(X, dtype=np.float32)
#     flat = X.ravel()
#     if flat.size == 0:
#         return X.astype(np.float32)

#     absflat = np.abs(flat)

#     if min_abs is not None:
#         thr = float(min_abs)
#         mask = absflat >= thr
#     else:
#         k = int(np.ceil(keep_ratio * flat.size))
#         k = max(1, min(k, flat.size))
#         if k == flat.size:
#             mask = np.ones_like(absflat, dtype=bool)
#         else:
#             thr = np.partition(absflat, -k)[-k]
#             mask = absflat >= thr

#     out = np.zeros_like(flat, dtype=np.float32)
#     out[mask] = flat[mask]
#     return out.reshape(X.shape).astype(np.float32)


# def _apply_csi_mode(data_csi):
#     """
#     Apply selected CSI_INPUT_MODE to a single sample array.
#     """
#     if CSI_INPUT_MODE == "raw":
#         return data_csi.astype(np.float32, copy=False)

#     elif CSI_INPUT_MODE == "lowrank":
#         return _low_rank_approx_keep_shape(
#             data_csi,
#             rank=LOW_RANK_RANK,
#             energy=LOW_RANK_ENERGY
#         )

#     elif CSI_INPUT_MODE == "sparse":
#         return _to_sparse_dense_keep_shape(
#             data_csi,
#             keep_ratio=SPARSE_KEEP_RATIO,
#             min_abs=SPARSE_MIN_ABS
#         )

#     else:
#         raise ValueError(f"Unknown CSI_INPUT_MODE: {CSI_INPUT_MODE}. Use 'raw', 'lowrank', or 'sparse'.")


# def load_data_y(var_path_data_y,
#                 var_environment=None, 
#                 var_wifi_band=None, 
#                 var_num_users=None):
#     """
#     Load annotation file (*.csv) as a pandas dataframe and filter by environment, WiFi band, and number of users.
#     """
#     data_pd_y = pd.read_csv(var_path_data_y, dtype=str)
#     if var_environment is not None:
#         data_pd_y = data_pd_y[data_pd_y["environment"].isin(var_environment)]
#     if var_wifi_band is not None:
#         data_pd_y = data_pd_y[data_pd_y["wifi_band"].isin(var_wifi_band)]
#     if var_num_users is not None:
#         data_pd_y = data_pd_y[data_pd_y["number_of_users"].isin(var_num_users)]
#     return data_pd_y


# def load_data_x(var_path_data_x, var_label_list):
#     """
#     Load CSI amplitude (*.npy) files based on a label list.
#     """
#     var_path_list = [os.path.join(var_path_data_x, var_label + ".npy") for var_label in var_label_list]
#     data_x = []
#     for var_path in var_path_list:
#         data_csi = np.load(var_path)

#         # NEW: convert input CSI according to selected mode (raw/lowrank/sparse)
#         data_csi = _apply_csi_mode(data_csi)

#         var_pad_length = preset["data"]["length"] - data_csi.shape[0]
#         data_csi_pad = np.pad(data_csi, ((var_pad_length, 0), (0, 0), (0, 0), (0, 0)))
#         data_x.append(data_csi_pad)
#     data_x = np.array(data_x)
#     return data_x


# def encode_data_y(data_pd_y, var_task):
#     """
#     Encode labels according to specific task.
#     """
#     if var_task == "identity":
#         data_y = encode_identity(data_pd_y)
#     elif var_task == "activity":
#         data_y = encode_activity(data_pd_y, preset["encoding"]["activity"])
#     elif var_task == "location":
#         data_y = encode_location(data_pd_y, preset["encoding"]["location"])
#     elif var_task == "count":
#         data_y = encode_count(data_pd_y, preset["encoding"]["location"])
#     return data_y


# def encode_identity(data_pd_y):
#     """
#     Onehot encoding for identity labels.
#     """
#     data_location_pd_y = data_pd_y[["user_1_location", "user_2_location", 
#                                     "user_3_location", "user_4_location", 
#                                     "user_5_location", "user_6_location"]]
#     data_identity_y = data_location_pd_y.to_numpy(copy=True).astype(str)
#     data_identity_y[data_identity_y != "nan"] = 1
#     data_identity_y[data_identity_y == "nan"] = 0
#     data_identity_onehot_y = data_identity_y.astype("int8")
#     return data_identity_onehot_y


# def encode_activity(data_pd_y, var_encoding):
#     """
#     Onehot encoding for activity labels.
#     """
#     data_activity_pd_y = data_pd_y[["user_1_activity", "user_2_activity", 
#                                     "user_3_activity", "user_4_activity", 
#                                     "user_5_activity", "user_6_activity"]]
#     data_activity_y = data_activity_pd_y.to_numpy(copy=True).astype(str)
#     data_activity_onehot_y = np.array([[var_encoding[var_y] for var_y in var_sample] for var_sample in data_activity_y])
#     return data_activity_onehot_y


# def encode_location(data_pd_y, var_encoding):
#     """
#     Onehot encoding for location labels.
#     """
#     data_location_pd_y = data_pd_y[["user_1_location", "user_2_location", 
#                                     "user_3_location", "user_4_location", 
#                                     "user_5_location", "user_6_location"]]
#     data_location_y = data_location_pd_y.to_numpy(copy=True).astype(str)
#     data_location_onehot_y = np.array([[var_encoding[var_y] for var_y in var_sample] for var_sample in data_location_y])
#     return data_location_onehot_y


# def encode_count(data_pd_y, var_encoding):
#     """
#     Onehot encoding for identity labels.
#     """
#     data_location_pd_y = data_pd_y[["user_1_location", "user_2_location", 
#                                     "user_3_location", "user_4_location", 
#                                     "user_5_location", "user_6_location"]]
#     data_identity_y = data_location_pd_y.to_numpy(copy=True).astype(str)
#     data_identity_y[data_identity_y != "nan"] = 1
#     data_identity_y[data_identity_y == "nan"] = 0
#     data_identity_onehot_y = data_identity_y.astype("int8")
#     print("data_identity_onehot_y",data_identity_onehot_y.shape)
#     count_data = np.sum(data_identity_onehot_y, axis=1)
#     print("count_data",count_data.shape)
#     count_data = count_data.reshape(-1, 1)  # shape = (11286, 1)
#     encoder = OneHotEncoder(sparse=False)  
#     count_data_onehot = encoder.fit_transform(count_data)
#     print(count_data_onehot.shape)  
#     count_data_onehot = count_data_onehot.astype("int8")

#     return count_data_onehot


# # Test functions (optional)
# def test_load_data_y():
#     print(load_data_y(preset["path"]["data_y"], var_environment=["classroom"]).describe())
#     print(load_data_y(preset["path"]["data_y"], var_environment=["meeting_room"], var_wifi_band=["2.4"]).describe())
#     print(load_data_y(preset["path"]["data_y"], var_environment=["meeting_room"], var_wifi_band=["2.4"], var_num_users=["1", "2", "3"]).describe())

# def test_load_data_x():
#     data_pd_y = load_data_y(preset["path"]["data_y"], var_environment=["meeting_room"], var_wifi_band=["2.4"], var_num_users=None)
#     var_label_list = data_pd_y["label"].to_list()
#     data_x = load_data_x(preset["path"]["data_x"], var_label_list)
#     print(data_x.shape)

# def test_encode_identity():
#     data_pd_y = pd.read_csv(preset["path"]["data_y"], dtype=str)
#     data_identity_onehot_y = encode_identity(data_pd_y)
#     print(data_identity_onehot_y.shape)
#     print(data_identity_onehot_y[2000])

# def test_encode_activity():
#     data_pd_y = pd.read_csv(preset["path"]["data_y"], dtype=str)
#     data_activity_onehot_y = encode_activity(data_pd_y, preset["encoding"]["activity"])
#     print(data_activity_onehot_y.shape)
#     print(data_activity_onehot_y[1560])

# def test_encode_location():
#     data_pd_y = pd.read_csv(preset["path"]["data_y"], dtype=str)
#     data_location_onehot_y = encode_location(data_pd_y, preset["encoding"]["location"])
#     print(data_location_onehot_y.shape)
#     print(data_location_onehot_y[1560])

# def test_encode_count():
#     data_pd_y = pd.read_csv(preset["path"]["data_y"], dtype=str)
#     data_count_onehot_y = encode_count(data_pd_y, preset["encoding"]["location"])
#     print(data_count_onehot_y.shape)
#     print(data_count_onehot_y[20])


In [5]:
# """
# [file]          load_data.py
# [description]   load annotation file and CSI amplitude, and encode labels
# """
# from sklearn.preprocessing import OneHotEncoder
# import numpy as np

# # Note: All necessary libraries (os, numpy, pandas, etc.) are imported in Cell 1.
# # from preset import preset   --> preset is already defined in Cell 2.

# def load_data_y(var_path_data_y,
#                 var_environment=None, 
#                 var_wifi_band=None, 
#                 var_num_users=None):
#     """
#     Load annotation file (*.csv) as a pandas dataframe and filter by environment, WiFi band, and number of users.
#     """
#     data_pd_y = pd.read_csv(var_path_data_y, dtype=str)
#     if var_environment is not None:
#         data_pd_y = data_pd_y[data_pd_y["environment"].isin(var_environment)]
#     if var_wifi_band is not None:
#         data_pd_y = data_pd_y[data_pd_y["wifi_band"].isin(var_wifi_band)]
#     if var_num_users is not None:
#         data_pd_y = data_pd_y[data_pd_y["number_of_users"].isin(var_num_users)]
#     return data_pd_y

# def load_data_x(var_path_data_x, var_label_list):
#     """
#     Load CSI amplitude (*.npy) files based on a label list.
#     """
#     var_path_list = [os.path.join(var_path_data_x, var_label + ".npy") for var_label in var_label_list]
#     data_x = []
#     for var_path in var_path_list:
#         data_csi = np.load(var_path)
#         var_pad_length = preset["data"]["length"] - data_csi.shape[0]
#         data_csi_pad = np.pad(data_csi, ((var_pad_length, 0), (0, 0), (0, 0), (0, 0)))
#         data_x.append(data_csi_pad)
#     data_x = np.array(data_x)
#     return data_x

# def encode_data_y(data_pd_y, var_task):
#     """
#     Encode labels according to specific task.
#     """
#     if var_task == "identity":
#         data_y = encode_identity(data_pd_y)
#     elif var_task == "activity":
#         data_y = encode_activity(data_pd_y, preset["encoding"]["activity"])
#     elif var_task == "location":
#         data_y = encode_location(data_pd_y, preset["encoding"]["location"])
#     elif var_task == "count":
#         data_y = encode_count(data_pd_y, preset["encoding"]["location"])
#     return data_y

# def encode_identity(data_pd_y):
#     """
#     Onehot encoding for identity labels.
#     """
#     data_location_pd_y = data_pd_y[["user_1_location", "user_2_location", 
#                                     "user_3_location", "user_4_location", 
#                                     "user_5_location", "user_6_location"]]
#     data_identity_y = data_location_pd_y.to_numpy(copy=True).astype(str)
#     data_identity_y[data_identity_y != "nan"] = 1
#     data_identity_y[data_identity_y == "nan"] = 0
#     data_identity_onehot_y = data_identity_y.astype("int8")
#     return data_identity_onehot_y



# def encode_activity(data_pd_y, var_encoding):
#     """
#     Onehot encoding for activity labels.
#     """
#     data_activity_pd_y = data_pd_y[["user_1_activity", "user_2_activity", 
#                                     "user_3_activity", "user_4_activity", 
#                                     "user_5_activity", "user_6_activity"]]
#     data_activity_y = data_activity_pd_y.to_numpy(copy=True).astype(str)
#     data_activity_onehot_y = np.array([[var_encoding[var_y] for var_y in var_sample] for var_sample in data_activity_y])
#     return data_activity_onehot_y

# def encode_location(data_pd_y, var_encoding):
#     """
#     Onehot encoding for location labels.
#     """
#     data_location_pd_y = data_pd_y[["user_1_location", "user_2_location", 
#                                     "user_3_location", "user_4_location", 
#                                     "user_5_location", "user_6_location"]]
#     data_location_y = data_location_pd_y.to_numpy(copy=True).astype(str)
#     data_location_onehot_y = np.array([[var_encoding[var_y] for var_y in var_sample] for var_sample in data_location_y])
#     return data_location_onehot_y

# # Test functions (optional)
# def test_load_data_y():
#     print(load_data_y(preset["path"]["data_y"], var_environment=["classroom"]).describe())
#     print(load_data_y(preset["path"]["data_y"], var_environment=["meeting_room"], var_wifi_band=["2.4"]).describe())
#     print(load_data_y(preset["path"]["data_y"], var_environment=["meeting_room"], var_wifi_band=["2.4"], var_num_users=["1", "2", "3"]).describe())

# def test_load_data_x():
#     data_pd_y = load_data_y(preset["path"]["data_y"], var_environment=["meeting_room"], var_wifi_band=["2.4"], var_num_users=None)
#     var_label_list = data_pd_y["label"].to_list()
#     data_x = load_data_x(preset["path"]["data_x"], var_label_list)
#     print(data_x.shape)

# def test_encode_identity():
#     data_pd_y = pd.read_csv(preset["path"]["data_y"], dtype=str)
#     data_identity_onehot_y = encode_identity(data_pd_y)
#     print(data_identity_onehot_y.shape)
#     print(data_identity_onehot_y[2000])

# def test_encode_activity():
#     data_pd_y = pd.read_csv(preset["path"]["data_y"], dtype=str)
#     data_activity_onehot_y = encode_activity(data_pd_y, preset["encoding"]["activity"])
#     print(data_activity_onehot_y.shape)
#     print(data_activity_onehot_y[1560])

# def test_encode_location():
#     data_pd_y = pd.read_csv(preset["path"]["data_y"], dtype=str)
#     data_location_onehot_y = encode_location(data_pd_y, preset["encoding"]["location"])
#     print(data_location_onehot_y.shape)
#     print(data_location_onehot_y[1560])

# def encode_count(data_pd_y, var_encoding):
#     """
#     Onehot encoding for identity labels.
#     """
#     data_location_pd_y = data_pd_y[["user_1_location", "user_2_location", 
#                                     "user_3_location", "user_4_location", 
#                                     "user_5_location", "user_6_location"]]
#     data_identity_y = data_location_pd_y.to_numpy(copy=True).astype(str)
#     data_identity_y[data_identity_y != "nan"] = 1
#     data_identity_y[data_identity_y == "nan"] = 0
#     data_identity_onehot_y = data_identity_y.astype("int8")
#     print("data_identity_onehot_y",data_identity_onehot_y.shape)
#     count_data = np.sum(data_identity_onehot_y, axis=1)
#     print("count_data",count_data.shape)
#     count_data = count_data.reshape(-1, 1)  # shape = (11286, 1)
#     encoder = OneHotEncoder(sparse=False)  
#     count_data_onehot = encoder.fit_transform(count_data)
#     print(count_data_onehot.shape)  
#     count_data_onehot = count_data_onehot.astype("int8")

#     return count_data_onehot


# def test_encode_count():
#     data_pd_y = pd.read_csv(preset["path"]["data_y"], dtype=str)
#     data_count_onehot_y = encode_count(data_pd_y)
#     print(data_count_onehot_y.shape)
#     print(data_count_onehot_y[20])

# # if __name__ == "__main__":
# #     test_encode_count()
# #     test_load_data_y()
# #     test_load_data_x()
# #     test_encode_identity()
# #     test_encode_activity()
# #     test_encode_location()


---
Cell 4: preprocess.py
---
---

In [6]:
"""
[file]          preprocess.py
[description]   preprocess WiFi CSI data
"""

# All necessary libraries are already imported in Cell 1.

# def mat_to_amp(data_mat):
#     """
#     Calculate amplitude of raw WiFi CSI data.
#     """
#     var_length = data_mat["trace"].shape[0]
#     data_csi_amp = [abs(data_mat["trace"][var_t][0][0][0][-1]) for var_t in range(var_length)]
#     data_csi_amp = np.array(data_csi_amp, dtype=np.float32)
#     return data_csi_amp

def extract_csi_amp(var_dir_mat, var_dir_amp):
    """
    Read raw WiFi CSI (*.mat) files, calculate CSI amplitude, and save as (*.npy).

    Only directory entries whose extension is ".mat" (case-insensitive) are
    processed, in sorted order; other entries are skipped instead of being
    handed to the MATLAB loader. The output directory is created if missing.
    Each result is written to "<var_dir_amp>/<file name without .mat>.npy".

    Args:
        var_dir_mat: directory containing the raw *.mat files.
        var_dir_amp: directory that receives the *.npy amplitude files.
    """
    os.makedirs(var_dir_amp, exist_ok=True)
    for var_name in sorted(os.listdir(var_dir_mat)):
        var_stem, var_ext = os.path.splitext(var_name)
        if var_ext.lower() != ".mat":
            continue
        data_mat = scio.loadmat(os.path.join(var_dir_mat, var_name))
        data_csi_amp = mat_to_amp(data_mat)
        # splitext swaps only the final extension, unlike str.replace, which
        # would rewrite every ".mat" occurrence in the file name.
        var_path_save = os.path.join(var_dir_amp, var_stem + ".npy")
        with open(var_path_save, "wb") as var_file:
            np.save(var_file, data_csi_amp)



# # Low-rank settings (the input of mat_to_amp is unchanged)
# LOW_RANK_ENERGY = 0.95   # e.g. keep 95% of the energy
# LOW_RANK_RANK = None     # if set to a number (e.g. 5), a fixed rank is used instead of ENERGY

# def _low_rank_approx(X, rank=None, energy=0.95):
#     X = np.asarray(X)

#     was_1d = (X.ndim == 1)
#     if was_1d:
#         X = X[:, None]

#     U, S, Vt = np.linalg.svd(X, full_matrices=False)

#     if rank is None:
#         s2 = S**2
#         cum = np.cumsum(s2) / (np.sum(s2) + 1e-12)
#         rank = int(np.searchsorted(cum, energy) + 1)

#     rank = max(1, min(rank, S.shape[0]))
#     X_lr = (U[:, :rank] * S[:rank]) @ Vt[:rank, :]

#     if was_1d:
#         X_lr = X_lr[:, 0]

#     return X_lr.astype(np.float32)

# def mat_to_amp(data_mat):
#     """
#     Calculate amplitude of raw WiFi CSI data, then return its low-rank approximation.
#     (Ÿàÿ±ŸàÿØ€å ÿ™ÿßÿ®ÿπ ÿ™ÿ∫€å€åÿ± ŸÜ⁄©ÿ±ÿØŸá)
#     """
#     var_length = data_mat["trace"].shape[0]
#     data_csi_amp = [abs(data_mat["trace"][var_t][0][0][0][-1]) for var_t in range(var_length)]
#     data_csi_amp = np.array(data_csi_amp, dtype=np.float32)

#     # ÿÆÿ±Ÿàÿ¨€å low-rank ÿ®ÿß ŸáŸÖÿßŸÜ ÿßÿ®ÿπÿßÿØ
#     data_csi_amp_lr = _low_rank_approx(
#         data_csi_amp,
#         rank=LOW_RANK_RANK,
#         energy=LOW_RANK_ENERGY
#     )
#     return data_csi_amp_lr



# # Sparsity settings (the input of mat_to_amp is unchanged)
# SPARSE_KEEP_RATIO = 0.10   # e.g. keep only the largest 10% of values
# SPARSE_MIN_ABS = None      # if set to a number (e.g. 0.5), a fixed threshold is used instead of keep_ratio

# def _to_sparse(X, keep_ratio=0.10, min_abs=None):
#     """
#     Convert X to a sparse representation by keeping only large-magnitude entries.
#     Returns:
#       - scipy.sparse.csr_matrix if SciPy is available
#       - otherwise returns a dense array with many zeros (still "sparse" in content)
#     """
#     X = np.asarray(X)
#     flat = X.ravel()
#     absflat = np.abs(flat)

#     if flat.size == 0:
#         return X.astype(np.float32)

#     # ÿßŸÜÿ™ÿÆÿßÿ® ÿ¢ÿ≥ÿ™ÿßŸÜŸá
#     if min_abs is not None:
#         thr = float(min_abs)
#         mask = absflat >= thr
#     else:
#         k = int(np.ceil(keep_ratio * flat.size))
#         k = max(1, min(k, flat.size))
#         if k == flat.size:
#             mask = np.ones_like(absflat, dtype=bool)
#         else:
#             thr = np.partition(absflat, -k)[-k]  # kth largest magnitude
#             mask = absflat >= thr

#     idx = np.nonzero(mask)[0]
#     data = flat[idx].astype(np.float32)

#     # ÿß⁄Øÿ± SciPy Ÿáÿ≥ÿ™: sparse ŸàÿßŸÇÿπ€å ÿ®ÿ≥ÿßÿ≤
#     try:
#         # ŸÖÿπŸÖŸàŸÑÿßŸã ÿ™Ÿà Cell1 €åÿß ÿßÿ≤ ŸÇÿ®ŸÑ import ÿ¥ÿØŸáÿõ ÿß⁄Øÿ± ŸáŸÖ ŸÜÿ¥ÿØŸá ÿ®ÿßÿ¥Ÿá ÿß€åŸÜÿ¨ÿß ÿ™ŸÑÿßÿ¥ ŸÖ€å‚Äå⁄©ŸÜŸá.
#         import scipy.sparse as sp

#         if X.ndim == 1:
#             rows = idx
#             cols = np.zeros_like(rows)
#             shape = (X.shape[0], 1)
#         else:
#             rows, cols = np.unravel_index(idx, X.shape)
#             shape = X.shape

#         return sp.coo_matrix((data, (rows, cols)), shape=shape).tocsr()

#     except Exception:
#         # fallback: ÿ¢ÿ±ÿß€åŸá‚Äå€å dense ÿ®ÿß ÿµŸÅÿ±Ÿáÿß€å ÿ≤€åÿßÿØ
#         out = np.zeros_like(flat, dtype=np.float32)
#         out[idx] = data
#         return out.reshape(X.shape)

def mat_to_amp(data_mat):
    """
    Calculate amplitude of raw WiFi CSI data, then return its sparse version.
    (The function's input is unchanged.)

    Reads ``data_mat["trace"]``, takes ``abs()`` of the element found at
    ``[0][0][0][-1]`` of every entry along the first axis, stacks the results
    into a float32 array and hands that array to ``_to_sparse``.

    NOTE(review): ``_to_sparse``, ``SPARSE_KEEP_RATIO`` and ``SPARSE_MIN_ABS``
    appear only as commented-out code above this function, so calling this
    variant would presumably raise NameError unless they are defined elsewhere.
    The name ``mat_to_amp`` is also defined again further down in this file,
    which rebinds it -- confirm whether this variant is still needed.
    """
    var_length = data_mat["trace"].shape[0]
    # One amplitude entry per index along the first axis of "trace".
    data_csi_amp = [abs(data_mat["trace"][var_t][0][0][0][-1]) for var_t in range(var_length)]
    data_csi_amp = np.array(data_csi_amp, dtype=np.float32)

    # Sparse output (CSR if SciPy is available)
    return _to_sparse(data_csi_amp, keep_ratio=SPARSE_KEEP_RATIO, min_abs=SPARSE_MIN_ABS)

# RPCA settings (you can change them); mat_to_amp forwards them to _rpca_ialm.
RPCA_MAX_ITER = 500     # upper bound on the number of IALM iterations
RPCA_TOL = 1e-7         # stop once the relative Frobenius residual falls below this value
RPCA_RHO = 1.5          # factor by which mu is multiplied after each iteration
RPCA_MU_INIT = None     # None means automatic
RPCA_LAMBDA = None      # None means 1/sqrt(max(m,n))

def _soft_threshold(X, tau):
    return np.sign(X) * np.maximum(np.abs(X) - tau, 0.0)

def _svt(X, tau):
    # Singular Value Thresholding
    U, s, Vt = np.linalg.svd(X, full_matrices=False)
    s_thr = np.maximum(s - tau, 0.0)
    # ÿß⁄Øÿ± ŸáŸÖŸá ÿµŸÅÿ± ÿ¥ÿØÿå ÿ≥ÿ±€åÿπ ÿ®ÿ±⁄Øÿ±ÿØ
    if np.all(s_thr == 0):
        return np.zeros_like(X)
    return (U * s_thr) @ Vt

def _rpca_ialm(M, lam=None, mu=None, rho=1.5, max_iter=500, tol=1e-7):
    """
    Robust PCA via the Inexact Augmented Lagrange Multiplier (IALM) method.

    Splits ``M`` into a low-rank part and a sparse part: ``M = L + S``.

    Args:
        M: 2-D NumPy array (its shape is unpacked as ``m, n``).
        lam: weight of the sparse term; ``None`` selects ``1 / sqrt(max(m, n))``.
        mu: initial penalty parameter; ``None`` selects ``1.25 / ||M||_2``.
        rho: factor applied to ``mu`` after every non-converged iteration.
        max_iter: upper bound on the number of iterations.
        tol: convergence threshold on ``||M - L - S||_F / ||M||_F``.

    Returns:
        Tuple ``(L, S)`` of float32 arrays with the same shape as ``M``.
    """
    M = M.astype(np.float64, copy=False)
    n_rows, n_cols = M.shape

    if lam is None:
        lam = 1.0 / np.sqrt(max(n_rows, n_cols))

    # Suggested (automatic) mu
    if mu is None:
        # ||M||_2 is the largest singular value of M.
        if M.size:
            spectral_norm = np.linalg.svd(M, compute_uv=False)[0]
        else:
            spectral_norm = 1.0
        mu = 1.25 / (spectral_norm + 1e-12)

    low_rank = np.zeros_like(M)
    sparse = np.zeros_like(M)
    dual = np.zeros_like(M)

    # The small offset keeps the relative residual finite for an all-zero M.
    frob_M = np.linalg.norm(M, ord='fro') + 1e-12

    for _ in range(max_iter):
        inv_mu = 1.0 / mu

        # Low-rank update: singular value thresholding
        low_rank = _svt(M - sparse + inv_mu * dual, inv_mu)

        # Sparse update: element-wise shrinkage
        sparse = _soft_threshold(M - low_rank + inv_mu * dual, lam / mu)

        # Dual (Lagrange multiplier) update
        residual = M - low_rank - sparse
        dual = dual + mu * residual

        # Stop once the relative reconstruction error is small enough
        if np.linalg.norm(residual, ord='fro') / frob_M < tol:
            break

        mu *= rho

    return low_rank.astype(np.float32), sparse.astype(np.float32)

# def mat_to_amp(data_mat):
#     """
#     Calculate amplitude of raw WiFi CSI data, then return RPCA sparse component S.
#     (The function's input is unchanged.)
#     """
#     var_length = data_mat["trace"].shape[0]
#     data_csi_amp = [abs(data_mat["trace"][var_t][0][0][0][-1]) for var_t in range(var_length)]
#     data_csi_amp = np.array(data_csi_amp, dtype=np.float32)

#     was_1d = (data_csi_amp.ndim == 1)
#     M = data_csi_amp[:, None] if was_1d else data_csi_amp

#     _, S = _rpca_ialm(
#         M,
#         lam=RPCA_LAMBDA,
#         mu=RPCA_MU_INIT,
#         rho=RPCA_RHO,
#         max_iter=RPCA_MAX_ITER,
#         tol=RPCA_TOL
#     )

#     if was_1d:
#         S = S[:, 0]

#     return S

def mat_to_amp(data_mat):
    """
    Calculate amplitude of raw WiFi CSI data, then return RPCA low-rank component L.
    (The function's input is unchanged.)

    Args:
        data_mat: mapping with a ``"trace"`` entry; ``abs()`` is taken of the
            element found at ``[0][0][0][-1]`` of every item along its first axis.

    Returns:
        float32 array with the same shape as the stacked amplitude array,
        holding the low-rank part returned by ``_rpca_ialm``. An empty trace
        yields an empty float32 array without running RPCA.
    """
    trace = data_mat["trace"]
    amplitude = np.array(
        [abs(trace[var_t][0][0][0][-1]) for var_t in range(trace.shape[0])],
        dtype=np.float32,
    )

    # Nothing to decompose (and reshape(..., -1) is ambiguous for size 0).
    if amplitude.size == 0:
        return amplitude

    # _rpca_ialm unpacks exactly two dimensions, so present the data as a
    # (time, features) matrix: 1-D input becomes a single column, 2-D input is
    # used as-is, and any extra trailing axes are flattened together.
    # NOTE(review): per-packet entries are presumably multi-dimensional
    # (e.g. antenna x antenna x subcarrier) -- confirm against the .mat files.
    matrix = amplitude.reshape(amplitude.shape[0], -1)

    low_rank, _ = _rpca_ialm(
        matrix,
        lam=RPCA_LAMBDA,
        mu=RPCA_MU_INIT,
        rho=RPCA_RHO,
        max_iter=RPCA_MAX_ITER,
        tol=RPCA_TOL
    )

    # Restore the layout of the stacked amplitude array.
    return low_rank.reshape(amplitude.shape)




def parse_args(argv=None):
    """
    Parse arguments from input.

    Args:
        argv: optional list of argument strings; ``None`` (the default) makes
            argparse read ``sys.argv[1:]``.

    Returns:
        argparse.Namespace with ``dir_mat`` (directory of raw ``.mat`` files)
        and ``dir_amp`` (directory for the amplitude files).

    Unrecognized arguments are ignored rather than aborting with SystemExit,
    matching the other ``parse_args`` in this file. This is presumably needed
    when the code runs inside a notebook kernel that adds its own arguments.
    """
    var_args = argparse.ArgumentParser()
    var_args.add_argument("--dir_mat", default="/kaggle/input/wimans/wifi_csi/mat", type=str)
    var_args.add_argument("--dir_amp", default="/kaggle/input/wimans/wifi_csi/amp", type=str)
    known_args, _unknown = var_args.parse_known_args(argv)
    return known_args

# if __name__ == "__main__":
#     var_args = parse_args()
#     extract_csi_amp(var_dir_mat=var_args.dir_mat, var_dir_amp=var_args.dir_amp)


---
Cell 5: that.py (WiFi-based Model THAT)
---
---

In [7]:
"""
[file]          that.py
[description]   implement and evaluate WiFi-based model THAT
                https://github.com/windofshadow/THAT
"""

# All necessary libraries are imported in Cell 1.
# from train import train   --> Defined in Cell 6.
# from preset import preset --> Defined in Cell 2.

class Gaussian_Position(torch.nn.Module):
    """
    Positional encoding built from a mixture of Gaussians over the time axis.

    Each time index gets softmax weights over ``var_num_gaussian`` Gaussians
    (from their log-densities); the weighted sum of the per-Gaussian
    embeddings is added to the input.
    """

    def __init__(self, var_dim_feature, var_dim_time, var_num_gaussian=10):
        super().__init__()
        # One learnable embedding vector per Gaussian, Xavier-initialized.
        self.var_embedding = torch.nn.Parameter(
            torch.zeros([var_num_gaussian, var_dim_feature], dtype=torch.float),
            requires_grad=True,
        )
        torch.nn.init.xavier_uniform_(self.var_embedding)
        # Time indices 0 .. var_dim_time-1, repeated for every Gaussian (frozen).
        time_index = torch.arange(0.0, var_dim_time).unsqueeze(1)
        self.var_position = torch.nn.Parameter(
            time_index.repeat(1, var_num_gaussian), requires_grad=False
        )
        # Learnable means, initially spaced evenly along the time axis.
        mean_step = var_dim_time / var_num_gaussian
        self.var_mu = torch.nn.Parameter(
            torch.arange(0.0, var_dim_time, mean_step).unsqueeze(0), requires_grad=True
        )
        # Learnable standard deviations, all starting at 50.
        self.var_sigma = torch.nn.Parameter(
            torch.tensor([50.0] * var_num_gaussian).unsqueeze(0), requires_grad=True
        )

    def calculate_pdf(self, var_position, var_mu, var_sigma):
        """Return ``-(position - mu)^2 / (2 * sigma^2) - log(sigma)``."""
        offset = var_position - var_mu
        scaled_square = -offset * offset / var_sigma / var_sigma / 2
        return scaled_square - torch.log(var_sigma)

    def forward(self, var_input):
        """Add the positional encoding (broadcast over the batch axis) to ``var_input``."""
        log_density = self.calculate_pdf(self.var_position, self.var_mu, self.var_sigma)
        mixture_weight = torch.softmax(log_density, dim=-1)
        encoding = torch.matmul(mixture_weight, self.var_embedding)
        return var_input + encoding.unsqueeze(0)

class Encoder(torch.nn.Module):
    """
    Encoder block: pre-norm self-attention followed by a multi-kernel CNN.

    Both sub-layers are wrapped in a residual connection. The CNN stage runs
    one Conv1d/BatchNorm/Dropout/LeakyReLU stack per kernel size and averages
    their outputs.

    Args:
        var_dim_feature: size of the last input axis (embedding dimension).
        var_num_head: number of attention heads.
        var_size_cnn: iterable of Conv1d kernel sizes, one stack per size.
            The default is a tuple so that no mutable object is shared
            between calls; lists are still accepted.
    """

    def __init__(self, var_dim_feature, var_num_head=10, var_size_cnn=(1, 3, 5)):
        super().__init__()
        self.layer_norm_0 = torch.nn.LayerNorm(var_dim_feature, eps=1e-6)
        self.layer_attention = torch.nn.MultiheadAttention(var_dim_feature, var_num_head, batch_first=True)
        self.layer_dropout_0 = torch.nn.Dropout(0.1)
        self.layer_norm_1 = torch.nn.LayerNorm(var_dim_feature, eps=1e-6)
        # padding="same" keeps the sequence length, so the stacks can be averaged.
        self.layer_cnn = torch.nn.ModuleList([
            torch.nn.Sequential(
                torch.nn.Conv1d(var_dim_feature, var_dim_feature, var_size, padding="same"),
                torch.nn.BatchNorm1d(var_dim_feature),
                torch.nn.Dropout(0.1),
                torch.nn.LeakyReLU(),
            )
            for var_size in var_size_cnn
        ])
        self.layer_dropout_1 = torch.nn.Dropout(0.1)

    def forward(self, var_input):
        """Map a (batch, sequence, var_dim_feature) tensor to one of the same shape."""
        # Self-attention sub-layer with residual connection.
        var_t = self.layer_norm_0(var_input)
        var_t, _ = self.layer_attention(var_t, var_t, var_t)
        var_t = self.layer_dropout_0(var_t)
        var_t = var_t + var_input
        # Conv1d expects (batch, channels, length), hence the permutes.
        var_s = self.layer_norm_1(var_t)
        var_s = torch.permute(var_s, (0, 2, 1))
        var_c = torch.stack([layer(var_s) for layer in self.layer_cnn], dim=0)
        var_s = torch.sum(var_c, dim=0) / len(self.layer_cnn)
        var_s = self.layer_dropout_1(var_s)
        var_s = torch.permute(var_s, (0, 2, 1))
        return var_s + var_t

class THAT(torch.nn.Module):
    """
    Two-branch THAT network.

    Both branches average-pool the time axis by a factor of 20. The left
    branch treats every pooled time step as a token (feature vector per
    step); the right branch treats every feature channel as a token (pooled
    time series per channel). Each branch ends in two Conv1d read-outs whose
    activations are summed over the last axis, and the concatenation of both
    branches feeds a linear output layer.

    Args:
        var_x_shape: shape of one sample; the last two entries are read as
            (time_steps, features).
        var_y_shape: shape of one flattened label; the last entry is the
            number of outputs.
    """

    def __init__(self, var_x_shape, var_y_shape):
        super().__init__()
        num_features = var_x_shape[-1]
        num_outputs = var_y_shape[-1]
        pooled_steps = var_x_shape[-2] // 20  # time length after pooling by 20

        # ---- Left branch: tokens are pooled time steps ----
        self.layer_left_pooling = torch.nn.AvgPool1d(kernel_size=20, stride=20)
        self.layer_left_gaussian = Gaussian_Position(num_features, pooled_steps)
        left_blocks = []
        for _ in range(4):
            left_blocks.append(
                Encoder(var_dim_feature=num_features, var_num_head=10, var_size_cnn=[1, 3, 5])
            )
        self.layer_left_encoder = torch.nn.ModuleList(left_blocks)
        self.layer_left_norm = torch.nn.LayerNorm(num_features, eps=1e-6)
        self.layer_left_cnn_0 = torch.nn.Conv1d(in_channels=num_features, out_channels=128, kernel_size=8)
        self.layer_left_cnn_1 = torch.nn.Conv1d(in_channels=num_features, out_channels=128, kernel_size=16)
        self.layer_left_dropout = torch.nn.Dropout(0.5)

        # ---- Right branch: tokens are feature channels ----
        self.layer_right_pooling = torch.nn.AvgPool1d(kernel_size=20, stride=20)
        right_blocks = []
        for _ in range(1):
            right_blocks.append(
                Encoder(var_dim_feature=pooled_steps, var_num_head=10, var_size_cnn=[1, 2, 3])
            )
        self.layer_right_encoder = torch.nn.ModuleList(right_blocks)
        self.layer_right_norm = torch.nn.LayerNorm(pooled_steps, eps=1e-6)
        self.layer_right_cnn_0 = torch.nn.Conv1d(in_channels=pooled_steps, out_channels=16, kernel_size=2)
        self.layer_right_cnn_1 = torch.nn.Conv1d(in_channels=pooled_steps, out_channels=16, kernel_size=4)
        self.layer_right_dropout = torch.nn.Dropout(0.5)

        self.layer_leakyrelu = torch.nn.LeakyReLU()
        # 2 x 128 channels from the left read-outs plus 2 x 16 from the right ones.
        self.layer_output = torch.nn.Linear(256 + 32, num_outputs)

    def _conv_readout(self, var_features, layer_conv_a, layer_conv_b):
        """Run both Conv1d read-outs, activate, sum over the last axis and concatenate."""
        pooled = [
            self.layer_leakyrelu(layer_conv(var_features)).sum(dim=-1)
            for layer_conv in (layer_conv_a, layer_conv_b)
        ]
        return torch.cat(pooled, dim=-1)

    def _forward_left(self, var_input):
        """Left branch: pooled time steps as tokens."""
        hidden = self.layer_left_pooling(var_input.permute(0, 2, 1)).permute(0, 2, 1)
        hidden = self.layer_left_gaussian(hidden)
        for block in self.layer_left_encoder:
            hidden = block(hidden)
        hidden = self.layer_left_norm(hidden).permute(0, 2, 1)
        hidden = self._conv_readout(hidden, self.layer_left_cnn_0, self.layer_left_cnn_1)
        return self.layer_left_dropout(hidden)

    def _forward_right(self, var_input):
        """Right branch: feature channels as tokens."""
        hidden = self.layer_right_pooling(var_input.permute(0, 2, 1))
        for block in self.layer_right_encoder:
            hidden = block(hidden)
        hidden = self.layer_right_norm(hidden).permute(0, 2, 1)
        hidden = self._conv_readout(hidden, self.layer_right_cnn_0, self.layer_right_cnn_1)
        return self.layer_right_dropout(hidden)

    def forward(self, var_input):
        """Map a (batch_size, time_steps, features) tensor to (batch_size, outputs) logits."""
        var_left = self._forward_left(var_input)
        var_right = self._forward_right(var_input)
        # Concatenate branches
        return self.layer_output(torch.cat([var_left, var_right], dim=-1))

def run_that(data_train_x, data_train_y, data_test_x, data_test_y, var_repeat=10, init_model=None):
    """
    Run WiFi-based model THAT.

    Trains and evaluates the model ``var_repeat`` times and collects the
    per-repeat classification report plus aggregate statistics.

    Args:
        data_train_x, data_test_x: NumPy arrays; every axis after the second
            is flattened so each sample becomes (time, features).
        data_train_y, data_test_y: NumPy label arrays; the last axis is used
            as the class axis for evaluation.
        var_repeat: number of repeated experiments.
        init_model: optional model to continue training; when given it is
            used instead of a fresh ``THAT`` and the learning rate is divided
            by 10.

    Returns:
        dict with one ``"repeat_<i>"`` classification report per repeat and
        ``"accuracy"``, ``"time_train"``, ``"time_test"`` entries holding
        ``{"avg", "std"}``.
    """
    device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
    # Collapse everything after the time axis into one feature axis.
    data_train_x = data_train_x.reshape(data_train_x.shape[0], data_train_x.shape[1], -1)
    data_test_x = data_test_x.reshape(data_test_x.shape[0], data_test_x.shape[1], -1)
    var_x_shape, var_y_shape = data_train_x[0].shape, data_train_y[0].reshape(-1).shape
    data_train_set = TensorDataset(torch.from_numpy(data_train_x), torch.from_numpy(data_train_y))
    data_test_set = TensorDataset(torch.from_numpy(data_test_x), torch.from_numpy(data_test_y))

    result = {}
    result_accuracy = []
    result_time_train = []
    result_time_test = []

    # var_macs, var_params = get_model_complexity_info(THAT(var_x_shape, var_y_shape), var_x_shape, as_strings=False)
    # print("Parameters:", var_params, "- FLOPs:", var_macs * 2)

    for var_r in range(var_repeat):
        print("Repeat", var_r)
        torch.random.manual_seed(var_r + 39)
        if init_model is not None:
            # Make sure a supplied model lives on the device the batches are
            # moved to; .to() is a no-op when it is already there.
            model_that = init_model.to(device)
            lr2 = preset["nn"]["lr"] /10
        else:
            model_that = THAT(var_x_shape, var_y_shape).to(device)
            lr2 = preset["nn"]["lr"]

        optimizer = torch.optim.Adam(model_that.parameters(), lr=lr2, weight_decay=0)
        # Positive labels are weighted 4x in the loss.
        loss = torch.nn.BCEWithLogitsLoss(pos_weight=torch.tensor([4] * var_y_shape[-1]).to(device))
        var_time_0 = time.time()

        # Train
        var_best_weight = train(model=model_that, optimizer=optimizer, loss=loss,
                                  data_train_set=data_train_set, data_test_set=data_test_set,
                                  var_threshold=preset["nn"]["threshold"],
                                  var_batch_size=preset["nn"]["batch_size"],
                                  var_epochs=preset["nn"]["epoch"],
                                  device=device)
        var_time_1 = time.time()

        # Test
        model_that.load_state_dict(var_best_weight)
        # Switch Dropout/BatchNorm to inference behaviour explicitly instead
        # of relying on the mode train() happened to leave the model in.
        model_that.eval()
        with torch.no_grad():
            predict_test_y = model_that(torch.from_numpy(data_test_x).to(device))
        predict_test_y = (torch.sigmoid(predict_test_y) > preset["nn"]["threshold"]).float()
        predict_test_y = predict_test_y.detach().cpu().numpy()
        var_time_2 = time.time()

        # Evaluate
        data_test_y_c = data_test_y.reshape(-1, data_test_y.shape[-1])
        predict_test_y_c = predict_test_y.reshape(-1, data_test_y.shape[-1])
        result_acc = accuracy_score(data_test_y_c.astype(int), predict_test_y_c.astype(int))
        result_dict = classification_report(data_test_y_c, predict_test_y_c, digits=6, zero_division=0, output_dict=True)
        result["repeat_" + str(var_r)] = result_dict
        result_accuracy.append(result_acc)
        result_time_train.append(var_time_1 - var_time_0)
        result_time_test.append(var_time_2 - var_time_1)
        print("repeat_" + str(var_r), result_accuracy)
        print(result)

    result["accuracy"] = {"avg": np.mean(result_accuracy), "std": np.std(result_accuracy)}
    result["time_train"] = {"avg": np.mean(result_time_train), "std": np.std(result_time_train)}
    result["time_test"] = {"avg": np.mean(result_time_test), "std": np.std(result_time_test)}
    # result["complexity"] = {"parameter": var_params, "flops": var_macs * 2}
    return result


---
Cell7: for RESNET18 Model
---
---

In [8]:
# import os
# os.environ['PYTORCH_CUDA_ALLOC_CONF'] = 'expandable_segments:True'

# import torch._dynamo
# torch._dynamo.config.suppress_errors = True
# import time
# import torch
# torch.cuda.empty_cache()
# import numpy as np
# from torch.utils.data import TensorDataset, DataLoader
# from sklearn.metrics import accuracy_score, classification_report
# import torchvision.models as models
# from copy import deepcopy

# torch.set_float32_matmul_precision("high")
# torch._dynamo.config.cache_size_limit = 65536

# # ŸÅÿ±ÿ∂ ŸÖ€å‚Äå⁄©ŸÜ€åŸÖ preset ŸÇÿ®ŸÑÿßŸã ÿ™ÿπÿ±€åŸÅ ÿ¥ÿØŸá ÿ®ÿßÿ¥Ÿá
# # preset = { "nn": {"lr": 1e-3, "epoch": 10, "batch_size": 4, "threshold": 0.5}, ... }

# class ResNet18Model(torch.nn.Module):
#     def __init__(self, var_x_shape, var_y_shape):
#         super(ResNet18Model, self).__init__()
#         model_resnet = models.resnet18(weights=None)
#         model_resnet.conv1 = torch.nn.Conv2d(1, 64, 7, 3, 2, bias=False)
#         in_features_fc = model_resnet.fc.in_features  # ŸÖÿπŸÖŸàŸÑÿßŸã 512
#         out_features_fc = var_y_shape[-1]
#         model_resnet.fc = torch.nn.Linear(in_features_fc, out_features_fc)
#         self.resnet = model_resnet

#     def forward(self, var_input):
#         var_input = var_input.reshape(var_input.size(0), 1, 3000, 270)
#         return self.resnet(var_input)

# def run_resnet(data_train_x, data_train_y, data_test_x, data_test_y, var_repeat=10, init_model=None):
#     device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
#     var_x_shape = data_train_x[0].shape
#     var_y_shape = data_train_y[0].reshape(-1).shape

#     # ÿ™ÿ∫€å€åÿ± ÿ¥⁄©ŸÑ ÿØÿßÿØŸá‚ÄåŸáÿß ÿ±Ÿà€å CPU
#     data_train_x = data_train_x.reshape(data_train_x.shape[0], 1, data_train_x.shape[1],
#                                         data_train_x.shape[2]*data_train_x.shape[3]*data_train_x.shape[4])
#     data_test_x  = data_test_x.reshape(data_test_x.shape[0], 1, data_test_x.shape[1],
#                                        data_test_x.shape[2]*data_test_x.shape[3]*data_test_x.shape[4])
    
#     # ÿØ€åÿ™ÿßÿ≥ÿ™‚ÄåŸáÿß ÿ±Ÿà€å CPU
#     data_train_set = TensorDataset(torch.from_numpy(data_train_x).float(),
#                                    torch.from_numpy(data_train_y).float())
#     data_test_set  = TensorDataset(torch.from_numpy(data_test_x).float(),
#                                    torch.from_numpy(data_test_y).float())
    
#     result = {}
#     result_accuracy = []
#     result_time_train = []
#     result_time_test = []
    
#     for var_r in range(var_repeat):
#         print("Repeat", var_r)
#         torch.random.manual_seed(var_r + 39)
        
#         # ÿ≥ÿßÿÆÿ™ ŸÖÿØŸÑ Ÿà ÿßŸÜÿ™ŸÇÿßŸÑ ÿ®Ÿá GPU
#         if init_model is not None:
#             model_resnet = init_model
#             lr2 = preset["nn"]["lr"] /10
            
#         else:
#             model_resnet = ResNet18Model(var_x_shape, var_y_shape).to(device)
#             lr2 = preset["nn"]["lr"]

#         optimizer = torch.optim.Adam(model_resnet.parameters(), lr=lr2, weight_decay=0)
#         loss_func = torch.nn.BCEWithLogitsLoss(pos_weight=torch.tensor([6] * var_y_shape[-1]).to(device))
        
#         # ÿ™ÿßÿ®ÿπ ÿ¢ŸÖŸàÿ≤ÿ¥ ÿØÿßÿÆŸÑ€åÿõ ÿØ€åÿ™ÿß ÿ±Ÿà€å CPU ÿ®ÿßŸÇ€å ŸÖ€å‚ÄåŸÖŸàŸÜŸá Ÿà ŸÅŸÇÿ∑ ŸáŸÜ⁄ØÿßŸÖ ŸÖÿ≠ÿßÿ≥ÿ®Ÿá batch ÿ®Ÿá GPU ŸÖ€åÿ±Ÿá
#         def train_inner():
#             train_loader = DataLoader(data_train_set, preset["nn"]["batch_size"], shuffle=True, pin_memory=False)
#             test_loader = DataLoader(data_test_set, preset["nn"]["batch_size"], shuffle=False, pin_memory=False)
#             best_accuracy = 0
#             best_weight = None
            
#             for epoch in range(preset["nn"]["epoch"]):
#                 t0 = time.time()
#                 model_resnet.train()
#                 # ŸÖÿ™ÿ∫€åÿ±Ÿáÿß€å ŸÖÿ±ÿ®Ÿàÿ∑ ÿ®Ÿá ÿ¢ÿÆÿ±€åŸÜ batch ÿ¢ŸÖŸàÿ≤ÿ¥
#                 last_train_loss = None
#                 last_train_acc = None
#                 for batch in train_loader:
#                     batch_x, batch_y = batch
#                     batch_x = batch_x.to(device)
#                     batch_y = batch_y.to(device)
#                     outputs = model_resnet(batch_x)
#                     loss_val = loss_func(outputs, batch_y.reshape(batch_y.shape[0], -1).float())
#                     optimizer.zero_grad()
#                     loss_val.backward()
#                     optimizer.step()
#                     last_train_loss = loss_val.item()
#                     # ŸÖÿ≠ÿßÿ≥ÿ®Ÿá ÿØŸÇÿ™ ÿ¢ÿÆÿ±€åŸÜ batch ÿ¢ŸÖŸàÿ≤ÿ¥
#                     train_preds = (torch.sigmoid(outputs) > preset["nn"]["threshold"]).float()
#                     last_train_acc = accuracy_score(batch_y.reshape(batch_y.shape[0], -1).detach().cpu().numpy().astype(int),
#                                                     train_preds.detach().cpu().numpy().astype(int))
                
#                 # ÿßÿ±ÿ≤€åÿßÿ®€å ÿ±Ÿà€å ÿØ€åÿ™ÿßÿ≥ÿ™ ÿ™ÿ≥ÿ™ ÿ®Ÿá ÿµŸàÿ±ÿ™ batch ÿ®Ÿá batch
#                 model_resnet.eval()
#                 all_preds = []
#                 all_labels = []
#                 test_loss_val = None
#                 with torch.no_grad():
#                     for t_batch in test_loader:
#                         t_x, t_y = t_batch
#                         t_x = t_x.to(device)
#                         outputs_test = model_resnet(t_x)
#                         outputs_test = (torch.sigmoid(outputs_test) > preset["nn"]["threshold"]).float()
#                         all_preds.append(outputs_test.detach().cpu().numpy())
#                         all_labels.append(t_y.cpu().numpy())  # ÿß€åŸÜÿ¨ÿß ÿ™ÿ∫€å€åÿ± ÿØÿßÿØ€åŸÖ
#                 preds_cat = np.vstack(all_preds)
#                 labels_cat = np.vstack(all_labels)
#                 print("preds_cat",preds_cat.shape)
#                 # ÿ™ÿ®ÿØ€åŸÑ ÿ®Ÿá ÿ¥⁄©ŸÑ (n, 6, 5)
                
#                 # preds_cat = preds_cat.reshape(-1, 6, 5)
#                 # labels_cat = labels_cat.reshape(-1, 6, 5)

#                 preds_cat = preds_cat.reshape(-1, 6)
#                 labels_cat = labels_cat.reshape(-1, 6)
                
#                 # ÿ®ÿ±ÿß€å ŸÖÿ≠ÿßÿ≥ÿ®Ÿá ÿØŸÇÿ™ÿå ŸÖÿ≥ÿ∑ÿ≠ ŸÖ€å‚Äå⁄©ŸÜ€åŸÖ
#                 test_acc = accuracy_score(labels_cat.reshape(labels_cat.shape[0], -1).astype(int),
#                                           preds_cat.reshape(preds_cat.shape[0], -1).astype(int))
#                 epoch_time = time.time() - t0
#                 print(f"Epoch {epoch}/{preset['nn']['epoch']} - "
#                       f"Train Loss: {(last_train_loss if last_train_loss is not None else 0.0):.6f}, "
#                       f"Train Acc: {(last_train_acc if last_train_acc is not None else 0.0):.6f}, "
#                       f"Test Loss: {(test_loss_val if test_loss_val is not None else 0.0):.6f}, "
#                       f"Test Acc: {(test_acc if test_acc is not None else 0.0):.6f} - "
#                       f"Time: {epoch_time:.4f}s")

#                 if test_acc > best_accuracy:
#                     best_accuracy = test_acc
#                     print('-----***-----')
#                     print(best_accuracy)
#                     best_weight = deepcopy(model_resnet.state_dict())
#             return best_weight
        
#         t0_run = time.time()
#         best_weight = train_inner()
#         t1_run = time.time()
        
#         torch.save(model_resnet.state_dict(), f"{name_run}_model_final.pt")
#         model_resnet.load_state_dict(best_weight)
#         torch.save(model_resnet.state_dict(), f"{name_run}_best_model.pt")

#         # bad age niaz bod load koni
#         # model_resnet = ResNet18Model(var_x_shape, var_y_shape).to(device)
#         # model_resnet.load_state_dict(torch.load("resnet_model_repeat0.pt"))
#         # model_resnet.eval()

        
#         # ÿßÿ±ÿ≤€åÿßÿ®€å ŸÜŸáÿß€å€å ŸÖÿØŸÑ ÿ±Ÿà€å ÿØ€åÿ™ÿßÿ≥ÿ™ ÿ™ÿ≥ÿ™ (ÿßÿ≥ÿ™ŸÅÿßÿØŸá ÿßÿ≤ batchŸáÿß€å ⁄©Ÿà⁄Ü⁄©)
#         model_resnet.eval()
#         all_preds = []
#         test_loader_final = DataLoader(data_test_set, preset["nn"]["batch_size"], shuffle=False, pin_memory=False)
#         with torch.no_grad():
#             for batch in test_loader_final:
#                 batch_x, _ = batch
#                 batch_x = batch_x.to(device)
#                 all_preds.append(model_resnet(batch_x))
#         preds_all = torch.cat(all_preds, dim=0)
#         preds_final = (torch.sigmoid(preds_all) > preset["nn"]["threshold"]).float().detach().cpu().numpy()
#         t2_run = time.time()
        
#         data_test_y_np = data_test_y.reshape(-1, data_test_y.shape[-1])
#         preds_final = preds_final.reshape(-1, data_test_y.shape[-1])
#         acc_final = accuracy_score(data_test_y_np.astype(int), preds_final.astype(int))
#         result[f"repeat_{var_r}"] = {"accuracy": acc_final}
#         result_accuracy.append(acc_final)
#         result_time_train.append(t1_run - t0_run)
#         result_time_test.append(t2_run - t1_run)
#         print("Repeat", var_r, "Final Test Accuracy:", acc_final)
    
#     result["accuracy"] = {"avg": np.mean(result_accuracy), "std": np.std(result_accuracy)}
#     result["time_train"] = {"avg": np.mean(result_time_train), "std": np.std(result_time_train)}
#     result["time_test"] = {"avg": np.mean(result_time_test), "std": np.std(result_time_test)}
#     return result


---
Cell 9: train.py
---
---

In [9]:
"""
[file]          train.py
[description]   function to train WiFi-based models
"""

# All necessary libraries are imported in Cell 1.

# Let PyTorch pick faster, reduced-precision kernels (e.g. TF32 where the
# hardware supports it) for float32 matrix multiplications.
torch.set_float32_matmul_precision("high")
# Raise TorchDynamo's cache size limit.
# NOTE(review): presumably meant to avoid hitting the recompilation limit
# under torch.compile -- confirm it is still needed.
torch._dynamo.config.cache_size_limit = 65536

def train(model, optimizer, loss, data_train_set, data_test_set, var_threshold, var_batch_size, var_epochs, device):
    """
    Generic training function for WiFi-based models.

    Trains ``model`` for ``var_epochs`` epochs, evaluates it on the whole test
    set after every epoch and remembers the weights with the best test
    accuracy. The final and the best weights are also written to
    ``{name_run}_model_final.pt`` and ``{name_run}_best_model.pt``.

    Args:
        model: network to train (expected to be on ``device`` already).
        optimizer: optimizer bound to the model's parameters.
        loss: callable ``loss(logits, targets)``.
        data_train_set, data_test_set: datasets yielding ``(x, y)`` pairs.
        var_threshold: threshold applied to the sigmoid outputs.
        var_batch_size: training batch size.
        var_epochs: number of epochs.
        device: device the batches are moved to.

    Returns:
        state dict of the epoch with the highest test accuracy.
    """
    # Data stays on the CPU (pin_memory=False); batches are moved on demand.
    train_loader = DataLoader(data_train_set, var_batch_size, shuffle=True, pin_memory=False)
    # The entire test set is served as one batch.
    test_loader = DataLoader(data_test_set, batch_size=len(data_test_set), shuffle=False, pin_memory=False)

    best_accuracy = -1.0
    best_weight = deepcopy(model.state_dict())

    for epoch_index in range(var_epochs):
        epoch_start = time.time()

        model.train()
        for batch_x, batch_y in train_loader:
            # Move the batch to the device only for this step.
            batch_x = batch_x.to(device)
            batch_y = batch_y.to(device)
            train_logits = model(batch_x)
            train_loss = loss(train_logits, batch_y.reshape(batch_y.shape[0], -1).float())
            optimizer.zero_grad()
            train_loss.backward()
            optimizer.step()

        # Training accuracy is measured on the last batch of the epoch only.
        train_pred = (torch.sigmoid(train_logits) > var_threshold).float()
        train_labels = batch_y.detach().cpu().numpy()
        train_pred = train_pred.detach().cpu().numpy()
        train_classes = train_labels.shape[-1]
        train_accuracy = accuracy_score(
            train_labels.reshape(-1, train_classes).astype(int),
            train_pred.reshape(-1, train_classes).astype(int),
        )

        model.eval()
        with torch.no_grad():
            test_x, test_y = next(iter(test_loader))
            # Move the test data to the device for the forward pass.
            test_x = test_x.to(device)
            test_y = test_y.to(device)

            test_logits = model(test_x)
            test_loss = loss(test_logits, test_y.reshape(test_y.shape[0], -1).float())
            test_pred = (torch.sigmoid(test_logits) > var_threshold).float()

            # Back to the CPU for scoring.
            test_labels = test_y.detach().cpu().numpy()
            test_pred = test_pred.detach().cpu().numpy()
            test_classes = test_labels.shape[-1]
            test_accuracy = accuracy_score(
                test_labels.reshape(-1, test_classes).astype(int),
                test_pred.reshape(-1, test_classes).astype(int),
            )

        print(f"Epoch {epoch_index}/{var_epochs}",
              "- %.6fs"%(time.time() - epoch_start),
              "- Loss %.6f"%train_loss.cpu(),
              "- Accuracy %.6f"%train_accuracy,
              "- Test Loss %.6f"%test_loss.cpu(),
              "- Test Accuracy %.6f"%test_accuracy)

        # Keep a snapshot whenever the test accuracy improves.
        if test_accuracy > best_accuracy:
            best_accuracy = test_accuracy
            print('-----***-----')
            print(best_accuracy)
            best_weight = deepcopy(model.state_dict())

    torch.save(model.state_dict(), f"{name_run}_model_final.pt")
    torch.save(best_weight, f"{name_run}_best_model.pt")

    return best_weight



# === Add the required imports once at the top of the file ===
from sklearn.metrics import confusion_matrix, ConfusionMatrixDisplay
from matplotlib.backends.backend_pdf import PdfPages
import matplotlib.pyplot as plt

# ---------- Helper function ----------
def save_confusion_matrix(model, data_loader, threshold, device, pdf_path):
    """
    Runs the model on `data_loader`, builds a confusion matrix and writes it to `pdf_path`.

    Predictions are ``sigmoid(logits) > threshold``; labels and predictions
    of every batch are flattened, so the matrix is computed over all
    individual label entries.

    Args:
        model: network producing logits.
        data_loader: iterable of ``(inputs, labels)`` batches.
        threshold: cut-off applied to the sigmoid outputs.
        device: device the inputs are moved to.
        pdf_path: destination of the single-page PDF.
    """
    model.eval()
    y_true, y_pred = [], []

    with torch.no_grad():
        for xb, yb in data_loader:
            xb = xb.to(device)
            logits = model(xb)

            preds = (torch.sigmoid(logits) > threshold).float().cpu().numpy().ravel()
            yb    = yb.cpu().numpy().ravel()

            y_true.extend(yb)
            y_pred.extend(preds)

    cm  = confusion_matrix(y_true, y_pred)
    fig, ax = plt.subplots()
    try:
        ConfusionMatrixDisplay(cm).plot(ax=ax)
        # En dash written as an escape so the title does not depend on the
        # encoding this file is stored or read in.
        ax.set_title("Confusion Matrix \u2013 Test")

        with PdfPages(pdf_path) as pdf:
            pdf.savefig(fig)
    finally:
        # Release the figure even when plotting or saving fails.
        plt.close(fig)
# ---------------------------------


---
Cell 11: run.py
---
---

In [10]:
import gc
import torch
# Free as much memory as possible before the run cell starts.
gc.collect()               # run a full Python garbage-collection pass
torch.cuda.empty_cache()   # release unused cached GPU memory held by PyTorch's allocator
torch.cuda.ipc_collect()   # collect GPU memory that was released by CUDA IPC




"""
[file]          run.py
[description]   run WiFi-based models and optionally save a multiclass confusion matrix
"""

import argparse
import json
import torch
from torch.utils.data import DataLoader, TensorDataset
from sklearn.model_selection import train_test_split
import matplotlib.pyplot as plt
from sklearn.metrics import confusion_matrix, ConfusionMatrixDisplay
from matplotlib.backends.backend_pdf import PdfPages

# from preset import preset, name_run
# from load_data import load_data_x, load_data_y, encode_data_y
# from lstm import run_lstm, LSTMM
# from bilstm import run_bilstm, BiLSTMM
# from that import run_that, THAT
# from resnet import run_resnet, ResNet18Model
# from strf import run_strf  # if you have the ST-RF implementation

def parse_args():
    """
    Parse the command-line options of the run script.

    Defaults for ``--model``, ``--task`` and ``--repeat`` come from
    ``preset``. Unrecognized arguments are ignored.

    Returns:
        argparse.Namespace with ``model``, ``task``, ``repeat`` and ``save_cm``.
    """
    cli = argparse.ArgumentParser()
    cli.add_argument("--model", type=str, default=preset["model"])
    cli.add_argument("--task", type=str, default=preset["task"])
    cli.add_argument("--repeat", type=int, default=preset["repeat"])
    cli.add_argument(
        "--save_cm",
        action="store_true",
        help="Save a multiclass confusion matrix of the best model to PDF",
    )
    known_args, _ignored = cli.parse_known_args()
    return known_args

def save_multiclass_confusion_matrix(model, data_loader, device, pdf_path, num_classes):
    """
    Given a model that outputs one-hot logits for a multiclass task,
    convert to predicted classes via argmax, then plot and save a
    num_classes x num_classes confusion matrix to pdf_path.

    Args:
        model: network producing per-class logits along axis 1.
        data_loader: iterable of ``(inputs, one_hot_labels)`` batches.
        device: device the inputs are moved to.
        pdf_path: destination of the single-page PDF.
        num_classes: number of classes; fixes the label set ``0..num_classes-1``.
    """
    model.eval()
    y_true = []
    y_pred = []
    with torch.no_grad():
        for xb, yb in data_loader:
            xb = xb.to(device)
            logits = model(xb)
            # predicted class is index of max logit
            preds = torch.argmax(logits, dim=1).cpu().numpy()
            trues = torch.argmax(yb, dim=1).cpu().numpy()
            y_pred.extend(preds.tolist())
            y_true.extend(trues.tolist())

    labels = list(range(num_classes))
    cm = confusion_matrix(y_true, y_pred, labels=labels)
    disp = ConfusionMatrixDisplay(cm, display_labels=labels)
    fig, ax = plt.subplots(figsize=(8, 8))
    try:
        disp.plot(ax=ax, xticks_rotation="vertical")
        ax.set_title("Confusion Matrix")
        with PdfPages(pdf_path) as pdf:
            pdf.savefig(fig)
    finally:
        # Release the figure even when plotting or saving fails.
        plt.close(fig)

def run():
    """Train and evaluate the selected model on the preset data selection.

    Steps:
      1. Read ``--model``, ``--task`` and ``--repeat`` via ``parse_args``.
      2. Resolve the matching ``run_*`` function; an unsupported model name
         raises before any data is read.
      3. Load the annotations filtered by the environment / WiFi band /
         number-of-users selections in ``preset["data"]``, load the inputs for
         the selected labels and encode the targets for the task.
      4. Split the data 80/20 with a fixed random seed, run the model and
         print the result together with the run settings.

    Nothing is returned; the result is only printed.

    Raises:
        ValueError: if the requested model name is not supported.
    """
    args       = parse_args()
    var_model  = args.model
    var_task   = args.task
    var_repeat = args.repeat

    # --- Select which model runner to use ---
    # Done before loading data so that a mistyped model name fails immediately
    # instead of after the whole dataset has been read.  The if/elif chain is
    # kept (rather than a dict literal) so that only the requested runner has
    # to be defined/importable.
    if var_model == "ST-RF":
        from strf import run_strf
        run_model = run_strf
    elif var_model == "LSTM":
        run_model = run_lstm
    elif var_model == "bi-LSTM":
        run_model = run_bilstm
    elif var_model == "THAT":
        run_model = run_that
    elif var_model == "ResNet18":
        run_model = run_resnet
    else:
        raise ValueError(f"Unknown model: {var_model}")

    # --- Load and encode the data ---
    data_pd_y = load_data_y(
        preset["path"]["data_y"],
        var_environment=preset["data"]["environment"],
        var_wifi_band=preset["data"]["wifi_band"],
        var_num_users=preset["data"]["num_users"]
    )
    labels = data_pd_y["label"].tolist()
    data_x = load_data_x(preset["path"]["data_x"], labels)
    data_y = encode_data_y(data_pd_y, var_task)

    # Fixed random_state keeps the train/test split reproducible across runs.
    train_x, test_x, train_y, test_y = train_test_split(
        data_x, data_y, test_size=0.2, shuffle=True, random_state=39
    )

    # --- Train and evaluate ---
    print(f"Running model: {var_model}")
    result = run_model(train_x, train_y, test_x, test_y, var_repeat)
    # Attach the run settings to the result so it is self-describing.
    result["model"] = var_model
    result["task"]  = var_task
    result["data"]  = preset["data"]
    result["nn"]    = preset["nn"]
    print(result)

    # --- Save results to JSON ---
    # with open(preset["path"]["save"], "w") as f:
    #     json.dump(result, f, indent=4)

    # # --- Optionally save a multiclass confusion matrix ---
    # # if args.save_cm:
    # if Confusion_matrix == 1:
    #     # 1) completely release GPU memory used for training
    #     del run_model                      # if 'model' from training is still in scope
    #     torch.cuda.empty_cache()
    #     torch.cuda.ipc_collect()
    
    #     # 2) reshape input only if the network is sequence-based
    #     if var_model in ("LSTM", "bi-LSTM", "THAT"):
    #         test_x_cm = test_x.reshape(test_x.shape[0], test_x.shape[1], -1)
    #     else:                           # ResNet18, ST-RF
    #         test_x_cm = test_x
    
    #     # 3) build the *same* architecture on CPU and load its weights
    #     device_cm = torch.device("cpu")
    #     if var_model == "LSTM":
    #         model_cm = LSTMM(test_x_cm[0].shape, test_y[0].shape).to(device_cm)
    #     elif var_model == "bi-LSTM":
    #         model_cm = BiLSTMM(test_x_cm[0].shape, test_y[0].shape).to(device_cm)
    #     elif var_model == "THAT":
    #         model_cm = THAT(test_x_cm[0].shape, test_y[0].shape).to(device_cm)
    #     elif var_model == "ResNet18":
    #         model_cm = ResNet18Model(test_x_cm[0].shape, test_y[0].shape).to(device_cm)
    #     else:
    #         raise ValueError(f"Confusion matrix not supported for {var_model}")
    
    #     # best_path = f"/kaggle/working/{name_run}_best_model.pt"
    #     # model_cm.load_state_dict(torch.load(best_path, map_location=device_cm))
    #     # model_cm.eval()
    
    #     # 4) DataLoader on CPU with a safe batch size
    #     test_ds = TensorDataset(torch.from_numpy(test_x_cm).float(),
    #                             torch.from_numpy(test_y).float())
    #     test_loader = DataLoader(test_ds, batch_size=64, shuffle=False)
    
    #     # 5) save the confusion matrix PDF
    #     num_classes = test_y.shape[1]
    #     pdf_name = f"{name_run}_confusion_matrix.pdf"
    #     save_multiclass_confusion_matrix(model_cm,test_loader,device_cm,pdf_name,num_classes)
    #     print(f"✅ Saved confusion matrix (classes 0–{num_classes-1}) to {pdf_name}")
# Entry point: print a start marker, then run the full experiment.
if __name__ == "__main__":
    print("start")
    run()


start
Running model: THAT
Repeat 0


  return F.conv1d(


Epoch 0/100 - 6.236140s - Loss 2.307054 - Accuracy 0.067708 - Test Loss 1.203212 - Test Accuracy 0.297966
-----***-----
0.2979664014146773
Epoch 1/100 - 5.133098s - Loss 0.900531 - Accuracy 0.276042 - Test Loss 0.682312 - Test Accuracy 0.549514
-----***-----
0.5495137046861185
Epoch 2/100 - 5.093784s - Loss 0.812541 - Accuracy 0.276042 - Test Loss 0.583160 - Test Accuracy 0.553050
-----***-----
0.553050397877984
Epoch 3/100 - 5.123330s - Loss 0.652456 - Accuracy 0.328125 - Test Loss 0.555206 - Test Accuracy 0.516799
Epoch 4/100 - 5.096751s - Loss 0.701374 - Accuracy 0.390625 - Test Loss 0.543965 - Test Accuracy 0.563660
-----***-----
0.5636604774535809
Epoch 5/100 - 5.100943s - Loss 0.642508 - Accuracy 0.354167 - Test Loss 0.528993 - Test Accuracy 0.560566
Epoch 6/100 - 5.093004s - Loss 0.552395 - Accuracy 0.432292 - Test Loss 0.494724 - Test Accuracy 0.559682
Epoch 7/100 - 5.125016s - Loss 0.573296 - Accuracy 0.416667 - Test Loss 0.505466 - Test Accuracy 0.557029
Epoch 8/100 - 5.08629

---
Cell 12: Few-shot Learning
---
---

In [11]:
# import gc
# import torch
# import shutil
# import json
# from sklearn.model_selection import train_test_split
# from torch.utils.data import DataLoader, TensorDataset
# import matplotlib.pyplot as plt
# from sklearn.metrics import confusion_matrix, ConfusionMatrixDisplay
# from matplotlib.backends.backend_pdf import PdfPages

# gc.collect()           
# torch.cuda.empty_cache()  
# torch.cuda.ipc_collect()

# # ---------- helper: save multiclass confusion matrix ------------------
# def save_multiclass_confusion_matrix(model, data_loader, pdf_path, num_classes):
#     """
#     Forward-pass on CPU, collect predictions, and write an N×N confusion matrix
#     to a single-page PDF (pdf_path).
#     """
#     model.eval()
#     y_true, y_pred = [], []
#     with torch.no_grad():
#         for xb, yb in data_loader:
#             logits = model(xb.cpu())                       # ensure CPU
#             preds  = torch.argmax(logits, dim=1).numpy()
#             trues  = torch.argmax(yb, dim=1).numpy()
#             y_pred.extend(preds.tolist())
#             y_true.extend(trues.tolist())

#     labels = list(range(num_classes))
#     cm  = confusion_matrix(y_true, y_pred, labels=labels)
#     disp = ConfusionMatrixDisplay(cm, display_labels=labels)
#     fig, ax = plt.subplots(figsize=(8, 8))
#     disp.plot(ax=ax, xticks_rotation="vertical")
#     ax.set_title("Few-shot Confusion Matrix")
#     with PdfPages(pdf_path) as pdf:
#         pdf.savefig(fig)
#     plt.close(fig)

# # -------------------- pick run_* function ------------------------------
# if preset["model"] == "ST-RF":
#     run_model = run_strf
# elif preset["model"] == "LSTM":
#     run_model = run_lstm
# elif preset["model"] == "bi-LSTM":
#     run_model = run_bilstm
# elif preset["model"] == "THAT":
#     run_model = run_that
# elif preset["model"] == "ResNet18":
#     run_model = run_resnet
# else:
#     raise ValueError(f"No few-shot implementation for {preset['model']}.")

# # ------------------------ load / split data ----------------------------
# data_pd_y = load_data_y(preset["path"]["data_y"],
#                         var_environment=[dest_env],
#                         var_wifi_band=preset["data"]["wifi_band"],
#                         var_num_users=preset["data"]["num_users"])

# labels_list = data_pd_y["label"].tolist()
# data_x = load_data_x(preset["path"]["data_x"], labels_list)
# data_y = encode_data_y(data_pd_y, preset["task"])

# train_x, test_x, train_y, test_y = train_test_split(
#     data_x, data_y, test_size=0.2, shuffle=True, random_state=39)

# # Few-shot sample size
# train_x = train_x[:few_shot_num_samples]
# train_y = train_y[:few_shot_num_samples]

# # ----------------------- few-shot training -----------------------------
# original_epochs = preset["nn"]["epoch"]
# preset["nn"]["epoch"] = few_shot_epochs

# # Load the best model weights
# best_model_path = f"{name_run}_best_model.pt"

# # Initialize the model 
# if preset["model"] == "LSTM":
#     model = LSTMM(train_x[0].reshape(train_x[0].shape[0], -1).shape, train_y[0].shape)  # Replace with your model initialization
#     # print('train_y_[0].shape:', train_y[0].shape)
#     # print('train_x_[0].shape:', train_x[0].reshape(train_x[0].shape[0], -1).shape)
# elif preset["model"] == "bi-LSTM":
#     model = BiLSTMM(train_x[0].reshape(train_x[0].shape[0], -1).shape, train_y[0].shape)  # Replace with your model initialization
# elif preset["model"] == "THAT":
#     model = THAT(train_x[0].reshape(train_x[0].shape[0], -1).shape, train_y[0].shape)  # Replace with your model initialization
# elif preset["model"] == "ResNet18":
#     model = ResNet18Model(train_x[0].reshape(train_x[0].shape[0], -1).shape, train_y[0].shape)  # Replace with your model initialization
# else:
#     raise ValueError(f"Model {preset['model']} not supported!")

# # Load the weights into the model
# model.load_state_dict(torch.load(best_model_path, map_location="cpu"))
# model = model.to('cuda')

# # Fine-tune the model on few-shot data (note: `run_model` should now return only the result)
# result = run_model(train_x, train_y, test_x, test_y, var_repeat=1, init_model=model)
# print(result)

# # --------------------- save few-shot checkpoints -----------------------
# # After fine-tuning, save the model
# torch.save(model.state_dict(), f"{name_run}_fewshot_final_model.pt")
# torch.save(model.state_dict(), f"{name_run}_fewshot_best_model.pt")

# # ------------------- confusion matrix on CPU ---------------------------
# if Confusion_matrix == 1 and preset["model"] != "ST-RF":

#     # reshape for sequence models
#     test_x_rs = (test_x.reshape(test_x.shape[0], test_x.shape[1], -1)
#                  if preset["model"] in ("LSTM", "bi-LSTM", "THAT") else test_x)

#     # instantiate identical architecture on CPU
#     if preset["model"] == "LSTM":
#         model_cpu = LSTMM(test_x_rs[0].shape, test_y[0].shape).cpu()
#     elif preset["model"] == "bi-LSTM":
#         model_cpu = BiLSTMM(test_x_rs[0].shape, test_y[0].shape).cpu()
#     elif preset["model"] == "THAT":
#         model_cpu = THAT(test_x_rs[0].shape, test_y[0].shape).cpu()
#     else:  # ResNet18
#         model_cpu = ResNet18Model(test_x_rs[0].shape, test_y[0].shape).cpu()

#     # load weights
#     model_cpu.load_state_dict(torch.load(f"{name_run}_fewshot_best_model.pt", map_location="cpu"))

#     # CPU DataLoader with a safe batch size
#     test_ds = TensorDataset(torch.from_numpy(test_x_rs).float(),
#                             torch.from_numpy(test_y).float())
#     test_loader = DataLoader(test_ds, batch_size=64, shuffle=False)

#     pdf_name = f"{name_run}_fewshot_confusion_matrix.pdf"
#     num_classes = test_y.shape[1]
#     save_multiclass_confusion_matrix(model_cpu, test_loader, pdf_name, num_classes)
#     print(f"✅ Saved few-shot confusion matrix (classes 0–{num_classes-1}) to {pdf_name}")

# # ----------------------- restore & persist -----------------------------
# preset["nn"]["epoch"] = original_epochs

# # Save the final result to JSON
# with open("result_fewshot.json", "w") as f:
#     json.dump(result, f, indent=4)


In [12]:
# import os
# import argparse
# import numpy as np
# import pandas as pd
# import scipy.io as scio
# import time
# import torch
# import gc
# from numpy.linalg import svd
# from sklearn.preprocessing import OneHotEncoder
# from sklearn.model_selection import train_test_split
# from sklearn.metrics import classification_report, accuracy_score, confusion_matrix, ConfusionMatrixDisplay
# from copy import deepcopy
# import json
# from torch.utils.data import TensorDataset, DataLoader
# import torch._dynamo
# from matplotlib.backends.backend_pdf import PdfPages
# import matplotlib.pyplot as plt

# # --- System settings ---
# torch.cuda.empty_cache()
# torch.set_float32_matmul_precision("high")

# # --------------------------
# # 1. Configuration
# # --------------------------
# preset = {
#     "model": "THAT",          
#     "task": "activity",       
#     "repeat": 1,
#     "path": {
#         "data_x": "/kaggle/input/wimans/wifi_csi/amp",   
#         "data_y": "/kaggle/input/wimans/annotation.csv", 
#     },
#     "data": {
#         "num_users": ["0", "1", "2", "3", "4", "5"],  
#         "wifi_band": ["2.4"],                         
#         "environment": ["classroom"],                 
#         "length": 3000,
        
#         # 1.0 = 100% data (Full run) | 0.1 = 10% data (Quick test)
#         "subset_ratio": 0.5,  
#     },
#     "nn": {
#         "lr": 1e-3,           
#         "epoch": 80,          
#         "batch_size": 32,    
#         "threshold": 0.5,
#         "patience": 5,        
#         "factor": 0.5,        
#         "min_lr": 1e-6        
#     },
#     "encoding": {
#         "activity": {
#             "nan":      [0, 0, 0, 0, 0, 0, 0, 0, 0],
#             "nothing":  [1, 0, 0, 0, 0, 0, 0, 0, 0],
#             "walk":     [0, 1, 0, 0, 0, 0, 0, 0, 0],
#             "rotation": [0, 0, 1, 0, 0, 0, 0, 0, 0],
#             "jump":     [0, 0, 0, 1, 0, 0, 0, 0, 0],
#             "wave":     [0, 0, 0, 0, 1, 0, 0, 0, 0],
#             "lie_down": [0, 0, 0, 0, 0, 1, 0, 0, 0],
#             "pick_up":  [0, 0, 0, 0, 0, 0, 1, 0, 0],
#             "sit_down": [0, 0, 0, 0, 0, 0, 0, 1, 0],
#             "stand_up": [0, 0, 0, 0, 0, 0, 0, 0, 1],
#         },
#     },
# }

# # --------------------------
# # 2. RPCA functions and data loading
# # --------------------------
# def soft_threshold(x, epsilon):
#     return np.maximum(np.abs(x) - epsilon, 0) * np.sign(x)

# def robust_pca(M, max_iter=10, tol=1e-4):
#     n1, n2 = M.shape
#     lambda_param = 1 / np.sqrt(max(n1, n2))
#     Y = M / np.maximum(np.linalg.norm(M, 2), np.linalg.norm(M, np.inf) / lambda_param)
#     L = np.zeros_like(M)
#     S = np.zeros_like(M)
#     mu = 1.25 / np.linalg.norm(M, 2)
#     rho = 1.5
#     for i in range(max_iter):
#         temp_L = M - S + (1/mu) * Y
#         U, Sigma, Vt = svd(temp_L, full_matrices=False)
#         Sigma_thresh = soft_threshold(Sigma, 1/mu)
#         L_new = np.dot(U * Sigma_thresh, Vt)
#         temp_S = M - L_new + (1/mu) * Y
#         S_new = soft_threshold(temp_S, lambda_param/mu)
#         error = np.linalg.norm(M - L_new - S_new, 'fro') / np.linalg.norm(M, 'fro')
#         L = L_new; S = S_new
#         if error < tol: break
#         Y = Y + mu * (M - L - S)
#         mu = min(mu * rho, 1e7)
#     return L, S

# def load_data_y(var_path_data_y, var_environment=None, var_wifi_band=None, var_num_users=None):
#     data_pd_y = pd.read_csv(var_path_data_y, dtype=str)
#     if var_environment is not None: data_pd_y = data_pd_y[data_pd_y["environment"].isin(var_environment)]
#     if var_wifi_band is not None: data_pd_y = data_pd_y[data_pd_y["wifi_band"].isin(var_wifi_band)]
#     if var_num_users is not None: data_pd_y = data_pd_y[data_pd_y["number_of_users"].isin(var_num_users)]
#     return data_pd_y

# def load_data_x(var_path_data_x, var_label_list, use_rpca=True):
#     var_path_list = [os.path.join(var_path_data_x, var_label + ".npy") for var_label in var_label_list]
#     data_x = []
#     mode_str = "WITH RPCA" if use_rpca else "RAW DATA (No RPCA)"
#     print(f"Loading {len(var_path_list)} samples - Mode: {mode_str}...")
#     for i, var_path in enumerate(var_path_list):
#         if i % 100 == 0 and i > 0: print(f"Processing {i}/{len(var_path_list)}...")
#         data_csi = np.load(var_path) 
#         data_csi_2d = data_csi.reshape(data_csi.shape[0], -1)
#         target_len = preset["data"]["length"]
#         current_len = data_csi_2d.shape[0]
#         var_pad_length = target_len - current_len
#         if var_pad_length > 0: data_csi_pad = np.pad(data_csi_2d, ((0, var_pad_length), (0, 0)), mode='constant')
#         else: data_csi_pad = data_csi_2d[:target_len, :]
#         if use_rpca:
#             L, S = robust_pca(data_csi_pad)
#             final_sample = np.concatenate([L, S], axis=1) 
#         else:
#             final_sample = data_csi_pad
#         data_x.append(final_sample)
#     data_x = np.array(data_x)
#     return data_x

# def encode_data_y(data_pd_y, var_task):
#     if var_task == "activity": return encode_activity(data_pd_y, preset["encoding"]["activity"])
#     return encode_activity(data_pd_y, preset["encoding"]["activity"])

# def encode_activity(data_pd_y, var_encoding):
#     cols = [f"user_{i}_activity" for i in range(1, 7)]
#     data = data_pd_y[cols].to_numpy(copy=True).astype(str)
#     return np.array([[var_encoding[y] for y in sample] for sample in data])

# # --------------------------
# # 3. THAT model
# # --------------------------
# class Gaussian_Position(torch.nn.Module):
#     def __init__(self, var_dim_feature, var_dim_time, var_num_gaussian=10):
#         super(Gaussian_Position, self).__init__()
#         self.var_embedding = torch.nn.Parameter(torch.zeros([var_num_gaussian, var_dim_feature]), requires_grad=True)
#         torch.nn.init.xavier_uniform_(self.var_embedding)
#         self.var_position = torch.nn.Parameter(torch.arange(0.0, var_dim_time).unsqueeze(1).repeat(1, var_num_gaussian), requires_grad=False)
#         self.var_mu = torch.nn.Parameter(torch.arange(0.0, var_dim_time, var_dim_time/var_num_gaussian).unsqueeze(0), requires_grad=True)
#         self.var_sigma = torch.nn.Parameter(torch.tensor([50.0] * var_num_gaussian).unsqueeze(0), requires_grad=True)
#     def forward(self, var_input):
#         var_pdf = - (self.var_position - self.var_mu)**2 / (2 * self.var_sigma**2) - torch.log(self.var_sigma)
#         var_pdf = torch.softmax(var_pdf, dim=-1)
#         return var_input + torch.matmul(var_pdf, self.var_embedding).unsqueeze(0)

# class Encoder(torch.nn.Module):
#     def __init__(self, var_dim_feature, var_num_head=10, var_size_cnn=[1, 3, 5]):
#         super(Encoder, self).__init__()
#         self.layer_norm_0 = torch.nn.LayerNorm(var_dim_feature, eps=1e-6)
#         self.layer_attention = torch.nn.MultiheadAttention(var_dim_feature, var_num_head, batch_first=True)
#         self.layer_dropout_0 = torch.nn.Dropout(0.1)
#         self.layer_norm_1 = torch.nn.LayerNorm(var_dim_feature, 1e-6)
#         self.layer_cnn = torch.nn.ModuleList([torch.nn.Sequential(torch.nn.Conv1d(var_dim_feature, var_dim_feature, s, padding="same"), torch.nn.BatchNorm1d(var_dim_feature), torch.nn.Dropout(0.1), torch.nn.LeakyReLU()) for s in var_size_cnn])
#         self.layer_dropout_1 = torch.nn.Dropout(0.1)
#     def forward(self, var_input):
#         var_t = self.layer_norm_0(var_input)
#         var_t, _ = self.layer_attention(var_t, var_t, var_t)
#         var_t = self.layer_dropout_0(var_t) + var_input
#         var_s = self.layer_norm_1(var_t).permute(0, 2, 1)
#         var_c = torch.stack([l(var_s) for l in self.layer_cnn], dim=0)
#         var_s = self.layer_dropout_1((torch.sum(var_c, dim=0) / len(self.layer_cnn)).permute(0, 2, 1))
#         return var_s + var_t

# class THAT(torch.nn.Module):
#     def __init__(self, var_x_shape, var_y_shape):
#         super(THAT, self).__init__()
#         var_dim_feature, var_dim_time = var_x_shape[-1], var_x_shape[-2]
#         var_dim_output = var_y_shape[-1]
#         self.layer_left_pooling = torch.nn.AvgPool1d(kernel_size=20, stride=20)
#         self.layer_left_gaussian = Gaussian_Position(var_dim_feature, var_dim_time // 20)
#         self.layer_left_encoder = torch.nn.ModuleList([Encoder(var_dim_feature, 10, [1, 3, 5]) for _ in range(4)])
#         self.layer_left_norm = torch.nn.LayerNorm(var_dim_feature, eps=1e-6)
#         self.layer_left_cnn = torch.nn.ModuleList([torch.nn.Conv1d(var_dim_feature, 128, k) for k in [8, 16]])
#         self.layer_left_dropout = torch.nn.Dropout(0.5)
#         var_dim_right = var_dim_time // 20
#         self.layer_right_pooling = torch.nn.AvgPool1d(kernel_size=20, stride=20)
#         self.layer_right_encoder = torch.nn.ModuleList([Encoder(var_dim_right, 10, [1, 2, 3])])
#         self.layer_right_norm = torch.nn.LayerNorm(var_dim_right, eps=1e-6)
#         self.layer_right_cnn = torch.nn.ModuleList([torch.nn.Conv1d(var_dim_right, 16, k) for k in [2, 4]])
#         self.layer_right_dropout = torch.nn.Dropout(0.5)
#         self.layer_leakyrelu = torch.nn.LeakyReLU()
#         self.layer_output = torch.nn.Linear(256 + 32, var_dim_output)
#     def forward(self, var_input):
#         v_l = self.layer_left_gaussian(self.layer_left_pooling(var_input.permute(0, 2, 1)).permute(0, 2, 1))
#         for l in self.layer_left_encoder: v_l = l(v_l)
#         v_l = self.layer_left_norm(v_l).permute(0, 2, 1)
#         v_l = torch.cat([torch.sum(self.layer_leakyrelu(cnn(v_l)), dim=-1) for cnn in self.layer_left_cnn], dim=-1)
#         v_l = self.layer_left_dropout(v_l)
#         v_r = self.layer_right_pooling(var_input.permute(0, 2, 1))
#         for l in self.layer_right_encoder: v_r = l(v_r)
#         v_r = self.layer_right_norm(v_r).permute(0, 2, 1)
#         v_r = torch.cat([torch.sum(self.layer_leakyrelu(cnn(v_r)), dim=-1) for cnn in self.layer_right_cnn], dim=-1)
#         v_r = self.layer_right_dropout(v_r)
#         return self.layer_output(torch.cat([v_l, v_r], dim=-1))

# # --------------------------
# # 4. Training Loop
# # --------------------------
# def train(model, optimizer, loss_fn, train_loader, test_loader, threshold, epochs, device, model_path):
#     best_acc = -1.0
#     best_w = deepcopy(model.state_dict())
    
#     scheduler = torch.optim.lr_scheduler.ReduceLROnPlateau(
#         optimizer, mode='max', factor=preset["nn"]["factor"], patience=preset["nn"]["patience"],
#         min_lr=preset["nn"]["min_lr"], verbose=True
#     )
    
#     for epoch in range(epochs):
#         t0 = time.time()
#         model.train()
        
#         # --- [MODIFIED] Using requested variable names ---
#         for data_batch_x, data_batch_y in train_loader:
#             data_batch_x = data_batch_x.to(device)
#             data_batch_y = data_batch_y.to(device)
            
#             optimizer.zero_grad()
            
#             predict_train_y = model(data_batch_x)
            
#             # --- [REQUESTED LINE] ---
#             loss_value = loss_fn(predict_train_y, data_batch_y.reshape(data_batch_y.shape[0], -1).float())
            
#             loss_value.backward()
#             optimizer.step()
            
#         model.eval()
#         with torch.no_grad():
#             tx, ty = next(iter(test_loader))
#             tx, ty = tx.to(device), ty.to(device)
#             pred_t = model(tx)
            
#             p_cls = (torch.sigmoid(pred_t) > threshold).float().cpu().numpy()
#             t_cls = ty.cpu().numpy()
#             acc = accuracy_score(t_cls.reshape(-1, t_cls.shape[-1]), p_cls.reshape(-1, t_cls.shape[-1]))
            
#         scheduler.step(acc)
#         current_lr = optimizer.param_groups[0]['lr']
#         print(f"Ep {epoch+1}/{epochs} | LR: {current_lr:.6f} | L_tr: {loss_value.item():.4f} | Acc: {acc:.4f}")
        
#         if acc > best_acc:
#             best_acc = acc
#             best_w = deepcopy(model.state_dict())
            
#     torch.save(best_w, model_path)
#     return best_w

# def save_multiclass_confusion_matrix(model, data_loader, device, pdf_path, num_classes, title_text):
#     model.eval()
#     y_true, y_pred = [], []
#     with torch.no_grad():
#         for xb, yb in data_loader:
#             xb = xb.to(device)
#             logits = model(xb) 
#             logits = logits.reshape(-1, num_classes) 
#             yb = yb.reshape(-1, num_classes)        
#             y_pred.extend(torch.argmax(logits, dim=1).cpu().numpy().tolist())
#             y_true.extend(torch.argmax(yb, dim=1).cpu().numpy().tolist())
    
#     labels = list(range(num_classes))
#     cm = confusion_matrix(y_true, y_pred, labels=labels)
#     disp = ConfusionMatrixDisplay(cm, display_labels=labels)
#     fig, ax = plt.subplots(figsize=(12, 12))
#     disp.plot(ax=ax, xticks_rotation="vertical", cmap='Blues')
#     ax.set_title(title_text)
#     with PdfPages(pdf_path) as pdf: pdf.savefig(fig)
#     plt.close(fig)

# # --------------------------
# # 5. Execution
# # --------------------------
# def run_experiment(scenario_name, use_rpca):
#     print(f"\n################################################")
#     print(f"STARTING SCENARIO: {scenario_name}")
#     print(f"################################################")
    
#     device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
#     current_run_name = f"{preset['model']}_{preset['task']}_{scenario_name}"
#     model_save_path = f"{current_run_name}_best_model.pt"
#     json_save_path = f"result_{current_run_name}.json"
#     pdf_save_path = f"Confusion_{current_run_name}.pdf"
    
#     # 1. Load Labels
#     data_pd_y = load_data_y(preset["path"]["data_y"], preset["data"]["environment"], preset["data"]["wifi_band"], preset["data"]["num_users"])
    
#     # Apply Subset Ratio
#     subset_ratio = preset["data"]["subset_ratio"]
#     if subset_ratio < 1.0:
#         data_pd_y = data_pd_y.sample(frac=subset_ratio, random_state=42).reset_index(drop=True)
#         print(f"*** DEBUG MODE: Using {subset_ratio*100}% of data ({len(data_pd_y)} samples) ***")
    
#     # 2. Load X
#     data_x = load_data_x(preset["path"]["data_x"], data_pd_y["label"].tolist(), use_rpca=use_rpca)
#     data_y = encode_data_y(data_pd_y, preset["task"])
    
#     # 3. Split
#     train_x, test_x, train_y, test_y = train_test_split(data_x, data_y, test_size=0.2, shuffle=True, random_state=39)
#     train_ds = TensorDataset(torch.from_numpy(train_x), torch.from_numpy(train_y))
#     test_ds = TensorDataset(torch.from_numpy(test_x), torch.from_numpy(test_y))
#     train_loader = DataLoader(train_ds, batch_size=preset["nn"]["batch_size"], shuffle=True)
#     test_loader = DataLoader(test_ds, batch_size=len(test_ds), shuffle=False)
    
#     result = {"accuracy": []}
    
#     for r in range(preset["repeat"]):
#         print(f"--- Repeat {r+1}/{preset['repeat']} ---")
#         torch.random.manual_seed(r + 39)
        
#         model = THAT(train_x[0].shape, train_y[0].reshape(-1).shape).to(device)
#         optimizer = torch.optim.Adam(model.parameters(), lr=preset["nn"]["lr"])
#         loss_fn = torch.nn.BCEWithLogitsLoss()
        
#         best_w = train(model, optimizer, loss_fn, train_loader, test_loader, 
#                        preset["nn"]["threshold"], preset["nn"]["epoch"], device, model_save_path)
        
#         model.load_state_dict(best_w)
#         with torch.no_grad():
#             preds = model(torch.from_numpy(test_x).to(device))
#             preds_reshaped = (torch.sigmoid(preds) > preset["nn"]["threshold"]).float().cpu().numpy().reshape(-1, 9)
#             targets_reshaped = test_y.reshape(-1, 9)
#             acc = accuracy_score(targets_reshaped, preds_reshaped)
#             result["accuracy"].append(acc)
            
#     print(f"Final Accuracy ({scenario_name}): {np.mean(result['accuracy']):.4f}")
#     with open(json_save_path, "w") as f: json.dump(result, f, indent=4)
    
#     print("Generating Confusion Matrix...")
#     model_cm = THAT(test_x[0].shape, test_y[0].reshape(-1).shape).to("cpu")
#     model_cm.load_state_dict(torch.load(model_save_path, map_location="cpu"))
#     cm_loader = DataLoader(TensorDataset(torch.from_numpy(test_x), torch.from_numpy(test_y)), batch_size=32)
#     num_classes = test_y.shape[2] 
#     title = f"Confusion Matrix: {scenario_name} (Acc: {np.mean(result['accuracy']):.2f} - {subset_ratio*100}% Data)"
#     save_multiclass_confusion_matrix(model_cm, cm_loader, "cpu", pdf_save_path, num_classes, title)
    
#     del model, model_cm, train_x, test_x, data_x
#     gc.collect()
#     torch.cuda.empty_cache()
#     print(f"Done with {scenario_name}.\n")

# def run():
#     scenarios = [
#         ("RPCA", True),
#         ("RAW", False)
#     ]
#     for name, rpca_flag in scenarios:
#         run_experiment(name, rpca_flag)

# if __name__ == "__main__":
#     run()