---
Cell 1: Library Imports
---
---

In [None]:
# Cell 1: Library Imports
import os
import argparse
import numpy as np
import pandas as pd
import scipy.io as scio
import time
import torch
torch.cuda.empty_cache()
import torch._dynamo
from torch.utils.data import TensorDataset, DataLoader
# from ptflops import get_model_complexity_info
from sklearn.metrics import classification_report, accuracy_score
from copy import deepcopy
import json
from sklearn.model_selection import train_test_split
from torch.utils.data import DataLoader, TensorDataset
import torch



---
Cell 2: preset.py
---
---

In [None]:
"""
[file]          preset.py
[description]   default settings of WiFi-based models
"""

# Global experiment configuration read by the loading, encoding and training code below.
# Top-level keys: "model", "task", "repeat", "path", "data", "nn", "encoding".
preset = {
    # define model
    "model": "ResNet18",  # "ST-RF", "MLP", "LSTM", "CNN-1D", "CNN-2D", "CLSTM", "ABLSTM", "THAT", "bi-LSTM", "ResNet18"
    # define task
    "task": "count",  # "identity", "activity", "location", "count"
    # number of repeated experiments
    "repeat": 1,
    # path of data
    "path": {
        "data_x": "/kaggle/input/wimans/wifi_csi/amp",   # directory of CSI amplitude files
        "data_y": "/kaggle/input/wimans/annotation.csv", # path of annotation file
        # NOTE(review): the file name below mentions lstm / epoch=80 / batchsize=32 / empty_room,
        # which does not match the model, epoch, batch size and environment set in this dict -- confirm.
        "save": "result_lstm_epoch=80_batchsize=32_envs=empty_room_wifiband=2.4.json"               # path to save results
    },
    # data selection for experiments
    "data": {
        "num_users": ["0", "1", "2", "3", "4", "5"],  # select number(s) of users
        "wifi_band": ["2.4"],                         # select WiFi band(s)
        "environment": ["classroom"],                 # select environment(s) ["classroom"], ["meeting_room"], ["empty_room"]
        "length": 3000,                               # default length of CSI (samples are front-padded to this length in load_data_x)
    },
    # hyperparameters of models
    "nn": {
        "lr": 1e-3,           # learning rate
        "epoch": 300,         # number of epochs
        "batch_size": 64,    # batch size
        "threshold": 0.5,     # threshold to binarize sigmoid outputs
    },
    # encoding of activities and locations
    "encoding": {
        "activity": {  # encoding of different activities ("nan" maps to the all-zero vector)
            "nan":      [0, 0, 0, 0, 0, 0, 0, 0, 0],
            "nothing":  [1, 0, 0, 0, 0, 0, 0, 0, 0],
            "walk":     [0, 1, 0, 0, 0, 0, 0, 0, 0],
            "rotation": [0, 0, 1, 0, 0, 0, 0, 0, 0],
            "jump":     [0, 0, 0, 1, 0, 0, 0, 0, 0],
            "wave":     [0, 0, 0, 0, 1, 0, 0, 0, 0],
            "lie_down": [0, 0, 0, 0, 0, 1, 0, 0, 0],
            "pick_up":  [0, 0, 0, 0, 0, 0, 1, 0, 0],
            "sit_down": [0, 0, 0, 0, 0, 0, 0, 1, 0],
            "stand_up": [0, 0, 0, 0, 0, 0, 0, 0, 1],
        },
        "location": {  # encoding of different locations ("nan" maps to the all-zero vector)
            "nan":  [0, 0, 0, 0, 0],
            "a":    [1, 0, 0, 0, 0],
            "b":    [0, 1, 0, 0, 0],
            "c":    [0, 0, 1, 0, 0],
            "d":    [0, 0, 0, 1, 0],
            "e":    [0, 0, 0, 0, 1],
        },
    },
}


# Few-shot settings (edited by hand before each run)
dest_env = "empty_room"       # destination environment: "classroom", "meeting_room" or "empty_room"
few_shot_epochs = 100         # number of epochs for few-shot training
few_shot_num_samples = 5      # number of samples taken from the destination test data

# NOTE(review): presumably a 0/1 switch for confusion-matrix output -- confirm against its users.
Confusion_matrix = 1

# Human-readable tag of this run, assembled from the few-shot settings and the preset.
name_run = (
    f"few={dest_env},{few_shot_epochs},{few_shot_num_samples},"
    f"m={preset['model']},t={preset['task']},"
    f"epoch={preset['nn']['epoch']},batch={preset['nn']['batch_size']},"
    f"environment={preset['data']['environment']}"
)
print(name_run)

---
Cell 3: load_data.py
---
---

In [None]:
"""
[file]          load_data.py
[description]   load annotation file and CSI amplitude, and encode labels
"""
from sklearn.preprocessing import OneHotEncoder
import numpy as np

# Note: All necessary libraries (os, numpy, pandas, etc.) are imported in Cell 1.
# from preset import preset   --> preset is already defined in Cell 2.

def load_data_y(var_path_data_y,
                var_environment=None,
                var_wifi_band=None,
                var_num_users=None):
    """
    Read the annotation file (*.csv) into a pandas dataframe with every column kept as str,
    then keep only the rows whose "environment", "wifi_band" and "number_of_users" values
    are contained in the corresponding argument. A filter left as None is not applied.
    """
    var_frame = pd.read_csv(var_path_data_y, dtype=str)
    # (column name, allowed values) pairs, applied one after another
    var_filters = (
        ("environment", var_environment),
        ("wifi_band", var_wifi_band),
        ("number_of_users", var_num_users),
    )
    for var_column, var_allowed in var_filters:
        if var_allowed is None:
            continue
        var_frame = var_frame[var_frame[var_column].isin(var_allowed)]
    return var_frame

def load_data_x(var_path_data_x, var_label_list, var_length=None):
    """
    Load CSI amplitude (*.npy) files based on a label list.

    Every file "<label>.npy" inside var_path_data_x is loaded, zero-padded at the
    front of its first axis up to var_length, and the padded samples are stacked
    in the order of var_label_list.

    Args:
        var_path_data_x: directory that contains the *.npy files.
        var_label_list: iterable of labels (file names without the ".npy" suffix).
        var_length: target size of the first axis of every sample; when None the
            value of preset["data"]["length"] is used.

    Returns:
        numpy array whose first axis indexes the samples.

    Raises:
        ValueError: if a sample is longer than var_length along its first axis.
    """
    if var_length is None:
        var_length = preset["data"]["length"]
    data_x = []
    for var_label in var_label_list:
        var_path = os.path.join(var_path_data_x, var_label + ".npy")
        data_csi = np.load(var_path)
        var_pad_length = var_length - data_csi.shape[0]
        if var_pad_length < 0:
            # np.pad would otherwise fail with an unrelated-looking message
            raise ValueError(
                "CSI sample '{}' has length {} which exceeds the target length {}".format(
                    var_path, data_csi.shape[0], var_length))
        # pad only the leading axis, whatever the rank of the stored array is
        var_pad_width = [(var_pad_length, 0)] + [(0, 0)] * (data_csi.ndim - 1)
        data_x.append(np.pad(data_csi, var_pad_width))
    return np.array(data_x)

def encode_data_y(data_pd_y, var_task):
    """
    Encode labels according to specific task.

    Args:
        data_pd_y: annotation dataframe passed on to the task-specific encoder.
        var_task: one of "identity", "activity", "location", "count".

    Returns:
        The array produced by the matching encode_* function.

    Raises:
        ValueError: if var_task is not one of the supported tasks.
    """
    if var_task == "identity":
        data_y = encode_identity(data_pd_y)
    elif var_task == "activity":
        data_y = encode_activity(data_pd_y, preset["encoding"]["activity"])
    elif var_task == "location":
        data_y = encode_location(data_pd_y, preset["encoding"]["location"])
    elif var_task == "count":
        # the location encoding is forwarded to keep the call signature uniform
        data_y = encode_count(data_pd_y, preset["encoding"]["location"])
    else:
        # previously fell through and failed on the return with an UnboundLocalError
        raise ValueError(
            "Unknown task {!r}; expected 'identity', 'activity', 'location' or 'count'".format(var_task))
    return data_y

def encode_identity(data_pd_y):
    """
    Presence encoding for identity labels.

    Reads the six "user_<i>_location" columns and returns an int8 array with one
    row per sample and one column per user: 1 where the location is set, 0 where
    its string form is "nan".
    """
    var_columns = ["user_{}_location".format(var_user) for var_user in range(1, 7)]
    data_location_y = data_pd_y[var_columns].to_numpy(copy=True).astype(str)
    # a user counts as present whenever a location is annotated for it
    return (data_location_y != "nan").astype("int8")



def encode_activity(data_pd_y, var_encoding):
    """
    Onehot encoding for activity labels.

    Looks up the string form of every entry of the six "user_<i>_activity" columns
    in var_encoding and stacks the resulting vectors sample by sample, user by user.
    A label missing from var_encoding raises KeyError.
    """
    var_columns = ["user_{}_activity".format(var_user) for var_user in range(1, 7)]
    data_activity_y = data_pd_y[var_columns].to_numpy(copy=True).astype(str)
    data_activity_onehot_y = []
    for var_sample in data_activity_y:
        data_activity_onehot_y.append([var_encoding[var_y] for var_y in var_sample])
    return np.array(data_activity_onehot_y)

def encode_location(data_pd_y, var_encoding):
    """
    Onehot encoding for location labels.

    Looks up the string form of every entry of the six "user_<i>_location" columns
    in var_encoding and stacks the resulting vectors sample by sample, user by user.
    A label missing from var_encoding raises KeyError.
    """
    var_columns = ["user_{}_location".format(var_user) for var_user in range(1, 7)]
    data_location_y = data_pd_y[var_columns].to_numpy(copy=True).astype(str)
    data_location_onehot_y = []
    for var_sample in data_location_y:
        data_location_onehot_y.append([var_encoding[var_y] for var_y in var_sample])
    return np.array(data_location_onehot_y)

# Test functions (optional)
def test_load_data_y():
    """Print describe() of the annotation table under three increasingly strict filters."""
    var_cases = [
        {"var_environment": ["classroom"]},
        {"var_environment": ["meeting_room"], "var_wifi_band": ["2.4"]},
        {"var_environment": ["meeting_room"], "var_wifi_band": ["2.4"], "var_num_users": ["1", "2", "3"]},
    ]
    for var_kwargs in var_cases:
        data_pd_y = load_data_y(preset["path"]["data_y"], **var_kwargs)
        print(data_pd_y.describe())

def test_load_data_x():
    """Load the CSI samples of the 2.4 GHz meeting-room recordings and print the array shape."""
    data_pd_y = load_data_y(
        preset["path"]["data_y"],
        var_environment=["meeting_room"],
        var_wifi_band=["2.4"],
        var_num_users=None,
    )
    data_x = load_data_x(preset["path"]["data_x"], data_pd_y["label"].to_list())
    print(data_x.shape)

def test_encode_identity():
    """Encode identities for the unfiltered annotation file; print the shape and row 2000."""
    var_annotation = pd.read_csv(preset["path"]["data_y"], dtype=str)
    var_encoded = encode_identity(var_annotation)
    print(var_encoded.shape)
    print(var_encoded[2000])

def test_encode_activity():
    """Encode activities for the unfiltered annotation file; print the shape and row 1560."""
    var_annotation = pd.read_csv(preset["path"]["data_y"], dtype=str)
    var_encoded = encode_activity(var_annotation, preset["encoding"]["activity"])
    print(var_encoded.shape)
    print(var_encoded[1560])

def test_encode_location():
    """Encode locations for the unfiltered annotation file; print the shape and row 1560."""
    var_annotation = pd.read_csv(preset["path"]["data_y"], dtype=str)
    var_encoded = encode_location(var_annotation, preset["encoding"]["location"])
    print(var_encoded.shape)
    print(var_encoded[1560])

def encode_count(data_pd_y, var_encoding=None):
    """
    Onehot encoding for the number of users present in each sample.

    A user is counted as present when its "user_<i>_location" entry is not "nan".
    The per-sample counts are then one-hot encoded over the sorted distinct counts
    that occur in data_pd_y, so column k stands for the k-th smallest count found
    and the width of the result depends on the data passed in.

    Implemented with numpy instead of sklearn's OneHotEncoder(sparse=False), whose
    `sparse` argument no longer exists in recent scikit-learn releases.

    Args:
        data_pd_y: annotation dataframe containing the six "user_<i>_location" columns.
        var_encoding: unused; kept (now optional) so existing two-argument calls still work.

    Returns:
        int8 array of shape (number of samples, number of distinct counts).
    """
    data_location_pd_y = data_pd_y[["user_1_location", "user_2_location",
                                    "user_3_location", "user_4_location",
                                    "user_5_location", "user_6_location"]]
    data_location_y = data_location_pd_y.to_numpy(copy=True).astype(str)
    data_identity_onehot_y = (data_location_y != "nan").astype("int8")
    print("data_identity_onehot_y",data_identity_onehot_y.shape)
    count_data = np.sum(data_identity_onehot_y, axis=1)
    print("count_data",count_data.shape)
    # sorted distinct counts play the role of the encoder categories
    var_classes, var_index = np.unique(count_data, return_inverse=True)
    count_data_onehot = np.eye(len(var_classes), dtype="int8")[var_index.reshape(-1)]
    print(count_data_onehot.shape)

    return count_data_onehot


def test_encode_count():
    """Encode user counts for the unfiltered annotation file; print the shape and row 20."""
    data_pd_y = pd.read_csv(preset["path"]["data_y"], dtype=str)
    # encode_count takes the encoding as second argument, exactly as encode_data_y calls it
    data_count_onehot_y = encode_count(data_pd_y, preset["encoding"]["location"])
    print(data_count_onehot_y.shape)
    print(data_count_onehot_y[20])

# if __name__ == "__main__":
#     test_encode_count()
#     test_load_data_y()
#     test_load_data_x()
#     test_encode_identity()
#     test_encode_activity()
#     test_encode_location()


---
Cell 4: preprocess.py
---
---

In [None]:
"""
[file]          preprocess.py
[description]   preprocess WiFi CSI data
"""

# All necessary libraries are already imported in Cell 1.

def mat_to_amp(data_mat):
    """
    Calculate amplitude of raw WiFi CSI data.

    For every index t along the first axis of data_mat["trace"], the entry
    data_mat["trace"][t][0][0][0][-1] is taken and its absolute value computed;
    the results are stacked into one float32 array.
    """
    data_trace = data_mat["trace"]
    data_csi_amp = []
    for var_t in range(data_trace.shape[0]):
        var_csi = data_trace[var_t][0][0][0][-1]
        data_csi_amp.append(abs(var_csi))
    return np.array(data_csi_amp, dtype=np.float32)

def extract_csi_amp(var_dir_mat, var_dir_amp):
    """
    Read raw WiFi CSI (*.mat) files, calculate CSI amplitude, and save as (*.npy).

    Only directory entries with a ".mat" extension (case-insensitive) are processed;
    anything else in var_dir_mat is skipped. Each output keeps the input file's
    stem and gets the ".npy" suffix.

    Args:
        var_dir_mat: directory containing the raw *.mat files.
        var_dir_amp: directory the *.npy files are written to; created if missing.
    """
    os.makedirs(var_dir_amp, exist_ok=True)
    # sorted only to make the processing order deterministic
    for var_name in sorted(os.listdir(var_dir_mat)):
        var_stem, var_ext = os.path.splitext(var_name)
        if var_ext.lower() != ".mat":
            continue
        data_mat = scio.loadmat(os.path.join(var_dir_mat, var_name))
        data_csi_amp = mat_to_amp(data_mat)
        # build the name from the stem so a ".mat" inside the stem is left untouched
        var_path_save = os.path.join(var_dir_amp, var_stem + ".npy")
        with open(var_path_save, "wb") as var_file:
            np.save(var_file, data_csi_amp)

def parse_args():
    """
    Parse the command-line options --dir_mat and --dir_amp (both str, both optional).
    """
    var_parser = argparse.ArgumentParser()
    var_options = (
        ("--dir_mat", "/kaggle/input/wimans/wifi_csi/mat"),
        ("--dir_amp", "/kaggle/input/wimans/wifi_csi/amp"),
    )
    for var_flag, var_default in var_options:
        var_parser.add_argument(var_flag, default=var_default, type=str)
    return var_parser.parse_args()

# if __name__ == "__main__":
#     var_args = parse_args()
#     extract_csi_amp(var_dir_mat=var_args.dir_mat, var_dir_amp=var_args.dir_amp)


---
Cell 5: that.py (WiFi-based Model THAT)
---
---

In [None]:
"""
[file]          that.py
[description]   implement and evaluate WiFi-based model THAT
                https://github.com/windofshadow/THAT
"""

# All necessary libraries are imported in Cell 1.
# from train import train   --> Defined in Cell 6.
# from preset import preset --> Defined in Cell 2.

class Gaussian_Position(torch.nn.Module):
    """
    Positional encoding built from a mixture of var_num_gaussian Gaussians over time.

    Every time step gets softmax weights over the Gaussians (from their log-densities)
    and the encoding is the weighted sum of one embedding vector per Gaussian.
    """

    def __init__(self, var_dim_feature, var_dim_time, var_num_gaussian=10):
        super().__init__()
        # one embedding vector per Gaussian, Xavier-initialised
        self.var_embedding = torch.nn.Parameter(
            torch.zeros([var_num_gaussian, var_dim_feature], dtype=torch.float),
            requires_grad=True)
        torch.nn.init.xavier_uniform_(self.var_embedding)
        # time indices 0 .. var_dim_time-1 repeated for every Gaussian (not trained)
        var_steps = torch.arange(0.0, var_dim_time)
        self.var_position = torch.nn.Parameter(
            var_steps.unsqueeze(1).repeat(1, var_num_gaussian),
            requires_grad=False)
        # means start evenly spaced over the time axis
        var_centres = torch.arange(0.0, var_dim_time, var_dim_time/var_num_gaussian)
        self.var_mu = torch.nn.Parameter(var_centres.unsqueeze(0), requires_grad=True)
        # every standard deviation starts at 50
        self.var_sigma = torch.nn.Parameter(
            torch.full((1, var_num_gaussian), 50.0),
            requires_grad=True)

    def calculate_pdf(self, var_position, var_mu, var_sigma):
        """Return -(position - mu)^2 / (2 * sigma^2) - log(sigma)."""
        var_diff = var_position - var_mu
        var_exponent = torch.neg(var_diff) * var_diff
        var_exponent = var_exponent / var_sigma / var_sigma / 2
        return var_exponent - torch.log(var_sigma)

    def forward(self, var_input):
        """Add the positional encoding to var_input, broadcast over the leading axis."""
        var_weight = torch.softmax(
            self.calculate_pdf(self.var_position, self.var_mu, self.var_sigma),
            dim=-1)
        var_position_encoding = torch.matmul(var_weight, self.var_embedding)
        return var_input + var_position_encoding.unsqueeze(0)

class Encoder(torch.nn.Module):
    """
    Encoder block: pre-norm multi-head self-attention with a residual connection,
    followed by a bank of parallel 1-D convolutions whose outputs are averaged and
    added back as a second residual.

    Args:
        var_dim_feature: size of the last axis of the input (embedding dimension);
            must be divisible by var_num_head as required by MultiheadAttention.
        var_num_head: number of attention heads.
        var_size_cnn: kernel sizes of the parallel convolutions (any sequence of ints).
            The default is a tuple so the default object cannot be mutated between calls.
    """

    def __init__(self, var_dim_feature, var_num_head=10, var_size_cnn=(1, 3, 5)):
        super(Encoder, self).__init__()
        self.layer_norm_0 = torch.nn.LayerNorm(var_dim_feature, eps=1e-6)
        self.layer_attention = torch.nn.MultiheadAttention(var_dim_feature, var_num_head, batch_first=True)
        self.layer_dropout_0 = torch.nn.Dropout(0.1)
        self.layer_norm_1 = torch.nn.LayerNorm(var_dim_feature, 1e-6)
        # one conv -> batch-norm -> dropout -> leaky-relu stack per kernel size;
        # "same" padding keeps the sequence length so the branches can be summed
        layer_cnn = []
        for var_size in var_size_cnn:
            layer = torch.nn.Sequential(
                torch.nn.Conv1d(var_dim_feature, var_dim_feature, var_size, padding="same"),
                torch.nn.BatchNorm1d(var_dim_feature),
                torch.nn.Dropout(0.1),
                torch.nn.LeakyReLU()
            )
            layer_cnn.append(layer)
        self.layer_cnn = torch.nn.ModuleList(layer_cnn)
        self.layer_dropout_1 = torch.nn.Dropout(0.1)

    def forward(self, var_input):
        """
        Apply the block to var_input of shape (batch, sequence, var_dim_feature)
        and return a tensor of the same shape.
        """
        # attention sub-block with residual
        var_t = self.layer_norm_0(var_input)
        var_t, _ = self.layer_attention(var_t, var_t, var_t)
        var_t = self.layer_dropout_0(var_t)
        var_t = var_t + var_input
        # convolution sub-block: Conv1d wants (batch, channels, sequence)
        var_s = self.layer_norm_1(var_t)
        var_s = torch.permute(var_s, (0, 2, 1))
        var_c = torch.stack([layer(var_s) for layer in self.layer_cnn], dim=0)
        # average over the parallel convolution branches
        var_s = torch.sum(var_c, dim=0) / len(self.layer_cnn)
        var_s = self.layer_dropout_1(var_s)
        var_s = torch.permute(var_s, (0, 2, 1))
        var_output = var_s + var_t
        return var_output

class THAT(torch.nn.Module):
    """
    Two-branch THAT model.

    The left branch attends over (average-pooled) time steps, the right branch over
    feature channels; each branch ends in two convolutions summed over their output
    axis, and the concatenated 256 + 32 features feed one linear output layer.
    The input is expected as (batch_size, time_steps, features).
    """

    def __init__(self, var_x_shape, var_y_shape):
        super().__init__()
        var_dim_time, var_dim_feature = var_x_shape[-2], var_x_shape[-1]
        var_dim_output = var_y_shape[-1]
        # number of time steps left after pooling with kernel/stride 20
        var_dim_pooled = var_dim_time // 20

        # ---- left branch: 4 encoders over the pooled time axis ----
        self.layer_left_pooling = torch.nn.AvgPool1d(kernel_size=20, stride=20)
        self.layer_left_gaussian = Gaussian_Position(var_dim_feature, var_dim_pooled)
        var_left_blocks = [
            Encoder(var_dim_feature=var_dim_feature, var_num_head=10, var_size_cnn=[1, 3, 5])
            for _ in range(4)
        ]
        self.layer_left_encoder = torch.nn.ModuleList(var_left_blocks)
        self.layer_left_norm = torch.nn.LayerNorm(var_dim_feature, eps=1e-6)
        self.layer_left_cnn_0 = torch.nn.Conv1d(in_channels=var_dim_feature, out_channels=128, kernel_size=8)
        self.layer_left_cnn_1 = torch.nn.Conv1d(in_channels=var_dim_feature, out_channels=128, kernel_size=16)
        self.layer_left_dropout = torch.nn.Dropout(0.5)

        # ---- right branch: 1 encoder over the feature axis ----
        self.layer_right_pooling = torch.nn.AvgPool1d(kernel_size=20, stride=20)
        var_right_blocks = [
            Encoder(var_dim_feature=var_dim_pooled, var_num_head=10, var_size_cnn=[1, 2, 3])
            for _ in range(1)
        ]
        self.layer_right_encoder = torch.nn.ModuleList(var_right_blocks)
        self.layer_right_norm = torch.nn.LayerNorm(var_dim_pooled, eps=1e-6)
        self.layer_right_cnn_0 = torch.nn.Conv1d(in_channels=var_dim_pooled, out_channels=16, kernel_size=2)
        self.layer_right_cnn_1 = torch.nn.Conv1d(in_channels=var_dim_pooled, out_channels=16, kernel_size=4)
        self.layer_right_dropout = torch.nn.Dropout(0.5)

        # ---- shared head ----
        self.layer_leakyrelu = torch.nn.LeakyReLU()
        self.layer_output = torch.nn.Linear(256 + 32, var_dim_output)

    def _conv_pair(self, var_t, layer_cnn_0, layer_cnn_1):
        """Run both convolutions with leaky-relu, sum each over its last axis, concatenate."""
        var_0 = self.layer_leakyrelu(layer_cnn_0(var_t))
        var_1 = self.layer_leakyrelu(layer_cnn_1(var_t))
        return torch.cat([var_0.sum(dim=-1), var_1.sum(dim=-1)], dim=-1)

    def _forward_left(self, var_input):
        """Left branch: pool over time, add positional encoding, encode, convolve."""
        var_left = self.layer_left_pooling(var_input.permute(0, 2, 1)).permute(0, 2, 1)
        var_left = self.layer_left_gaussian(var_left)
        for layer in self.layer_left_encoder:
            var_left = layer(var_left)
        var_left = self.layer_left_norm(var_left).permute(0, 2, 1)
        var_left = self._conv_pair(var_left, self.layer_left_cnn_0, self.layer_left_cnn_1)
        return self.layer_left_dropout(var_left)

    def _forward_right(self, var_input):
        """Right branch: pool over time, encode with features as the sequence, convolve."""
        var_right = self.layer_right_pooling(var_input.permute(0, 2, 1))
        for layer in self.layer_right_encoder:
            var_right = layer(var_right)
        var_right = self.layer_right_norm(var_right).permute(0, 2, 1)
        var_right = self._conv_pair(var_right, self.layer_right_cnn_0, self.layer_right_cnn_1)
        return self.layer_right_dropout(var_right)

    def forward(self, var_input):
        # var_input shape: (batch_size, time_steps, features)
        # the left branch is evaluated before the right one, as in the layer order above
        var_left = self._forward_left(var_input)
        var_right = self._forward_right(var_input)
        var_t = torch.cat([var_left, var_right], dim=-1)
        return self.layer_output(var_t)

def run_that(data_train_x, data_train_y, data_test_x, data_test_y, var_repeat=10):
    """
    Run WiFi-based model THAT.

    Trains a fresh THAT model var_repeat times (torch seed 39 + repeat index),
    evaluates each on the test data and aggregates accuracy and timings.

    Args:
        data_train_x, data_test_x: numpy arrays; all axes after the second one
            are flattened into a single feature axis.
        data_train_y, data_test_y: numpy label arrays.
        var_repeat: number of repeated experiments.

    Returns:
        dict with one "repeat_<r>" classification report per repeat plus
        "accuracy", "time_train" and "time_test", each as {"avg": ..., "std": ...}.
    """
    device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
    # flatten everything behind the time axis: (samples, time, features)
    data_train_x = data_train_x.reshape(data_train_x.shape[0], data_train_x.shape[1], -1)
    data_test_x = data_test_x.reshape(data_test_x.shape[0], data_test_x.shape[1], -1)
    var_x_shape, var_y_shape = data_train_x[0].shape, data_train_y[0].reshape(-1).shape
    data_train_set = TensorDataset(torch.from_numpy(data_train_x), torch.from_numpy(data_train_y))
    data_test_set = TensorDataset(torch.from_numpy(data_test_x), torch.from_numpy(data_test_y))

    result = {}
    result_accuracy = []
    result_time_train = []
    result_time_test = []

    # var_macs, var_params = get_model_complexity_info(THAT(var_x_shape, var_y_shape), var_x_shape, as_strings=False)
    # print("Parameters:", var_params, "- FLOPs:", var_macs * 2)

    for var_r in range(var_repeat):
        print("Repeat", var_r)
        torch.random.manual_seed(var_r + 39)
        model_that = THAT(var_x_shape, var_y_shape).to(device)
        optimizer = torch.optim.Adam(model_that.parameters(), lr=preset["nn"]["lr"], weight_decay=0)
        loss = torch.nn.BCEWithLogitsLoss(pos_weight=torch.tensor([4] * var_y_shape[-1]).to(device))
        var_time_0 = time.time()

        # Train (train() is defined elsewhere in the notebook and returns the weights to evaluate)
        var_best_weight = train(model=model_that, optimizer=optimizer, loss=loss,
                                  data_train_set=data_train_set, data_test_set=data_test_set,
                                  var_threshold=preset["nn"]["threshold"],
                                  var_batch_size=preset["nn"]["batch_size"],
                                  var_epochs=preset["nn"]["epoch"],
                                  device=device)
        var_time_1 = time.time()

        # Test
        model_that.load_state_dict(var_best_weight)
        # inference must not depend on whichever mode train() left the model in:
        # switch dropout and batch-norm layers to evaluation behaviour explicitly
        model_that.eval()
        with torch.no_grad():
            predict_test_y = model_that(torch.from_numpy(data_test_x).to(device))
        predict_test_y = (torch.sigmoid(predict_test_y) > preset["nn"]["threshold"]).float()
        predict_test_y = predict_test_y.detach().cpu().numpy()
        var_time_2 = time.time()

        # Evaluate
        data_test_y_c = data_test_y.reshape(-1, data_test_y.shape[-1])
        predict_test_y_c = predict_test_y.reshape(-1, data_test_y.shape[-1])
        result_acc = accuracy_score(data_test_y_c.astype(int), predict_test_y_c.astype(int))
        result_dict = classification_report(data_test_y_c, predict_test_y_c, digits=6, zero_division=0, output_dict=True)
        result["repeat_" + str(var_r)] = result_dict
        result_accuracy.append(result_acc)
        result_time_train.append(var_time_1 - var_time_0)
        result_time_test.append(var_time_2 - var_time_1)
        print("repeat_" + str(var_r), result_accuracy)
        print(result)

    result["accuracy"] = {"avg": np.mean(result_accuracy), "std": np.std(result_accuracy)}
    result["time_train"] = {"avg": np.mean(result_time_train), "std": np.std(result_time_train)}
    result["time_test"] = {"avg": np.mean(result_time_test), "std": np.std(result_time_test)}
    # result["complexity"] = {"parameter": var_params, "flops": var_macs * 2}
    return result


---
Cell 6: LSTM Model (lstm.py)
---
---

In [None]:
"""
[file]          lstm.py
[description]   implement and evaluate WiFi-based model LSTM
"""
import torch._dynamo
# torch._dynamo.config.suppress_errors = True
import time
import torch
torch.cuda.empty_cache()
import numpy as np
from torch.utils.data import TensorDataset
# from ptflops import get_model_complexity_info
from sklearn.metrics import classification_report, accuracy_score
# from preset import preset

class LSTMM(torch.nn.Module):
    """
    LSTM classifier: batch-norm over the feature channels, average pooling over
    time (kernel 10, stride 10), a single LSTM with 512 hidden units, and a linear
    layer applied to the output of the last time step.
    The input is expected as (batch, time, features).
    """

    def __init__(self, var_x_shape, var_y_shape):
        super().__init__()
        var_dim_input, var_dim_output = var_x_shape[-1], var_y_shape[-1]
        self.layer_norm = torch.nn.BatchNorm1d(var_dim_input)
        self.layer_pooling = torch.nn.AvgPool1d(10, 10)
        self.layer_lstm = torch.nn.LSTM(input_size=var_dim_input,
                                        hidden_size=512,
                                        batch_first=True)
        self.layer_linear = torch.nn.Linear(512, var_dim_output)

    def forward(self, var_input):
        # BatchNorm1d / AvgPool1d work on (batch, channels, time)
        var_channels_first = var_input.permute(0, 2, 1)
        var_pooled = self.layer_pooling(self.layer_norm(var_channels_first))
        # back to (batch, time, features) for the batch_first LSTM
        var_sequence, _ = self.layer_lstm(var_pooled.permute(0, 2, 1))
        # classify from the output at the last time step
        return self.layer_linear(var_sequence[:, -1, :])

def run_lstm(data_train_x, data_train_y, data_test_x, data_test_y, var_repeat=10):
    """
    Run WiFi-based model LSTM.

    Trains a fresh LSTMM model var_repeat times (torch seed 39 + repeat index),
    evaluates each on the test data and aggregates accuracy and timings.

    Args:
        data_train_x, data_test_x: numpy arrays; all axes after the second one
            are flattened into a single feature axis.
        data_train_y, data_test_y: numpy label arrays.
        var_repeat: number of repeated experiments.

    Returns:
        dict with one "repeat_<r>" classification report per repeat plus
        "accuracy", "time_train" and "time_test", each as {"avg": ..., "std": ...}.
    """
    device = torch.device("cuda" if torch.cuda.is_available() else "cpu")

    # flatten everything behind the time axis: (samples, time, features)
    data_train_x = data_train_x.reshape(data_train_x.shape[0], data_train_x.shape[1], -1)
    data_test_x = data_test_x.reshape(data_test_x.shape[0], data_test_x.shape[1], -1)

    var_x_shape, var_y_shape = data_train_x[0].shape, data_train_y[0].reshape(-1).shape

    data_train_set = TensorDataset(torch.from_numpy(data_train_x), torch.from_numpy(data_train_y))
    data_test_set = TensorDataset(torch.from_numpy(data_test_x), torch.from_numpy(data_test_y))

    result = {}
    result_accuracy = []
    result_time_train = []
    result_time_test = []

    # var_macs, var_params = get_model_complexity_info(LSTMM(var_x_shape, var_y_shape), var_x_shape, as_strings=False)
    # print("Parameters:", var_params, "- FLOPs:", var_macs * 2)

    for var_r in range(var_repeat):
        print("Repeat", var_r)
        torch.random.manual_seed(var_r + 39)
        model_lstm = LSTMM(var_x_shape, var_y_shape).to(device)
        optimizer = torch.optim.Adam(model_lstm.parameters(),
                                     lr=preset["nn"]["lr"],
                                     weight_decay=0)
        loss = torch.nn.BCEWithLogitsLoss(pos_weight=torch.tensor([6] * var_y_shape[-1]).to(device))
        var_time_0 = time.time()

        # Train (train() is defined elsewhere in the notebook and returns the weights to evaluate)
        var_best_weight = train(model=model_lstm,
                                optimizer=optimizer,
                                loss=loss,
                                data_train_set=data_train_set,
                                data_test_set=data_test_set,
                                var_threshold=preset["nn"]["threshold"],
                                var_batch_size=preset["nn"]["batch_size"],
                                var_epochs=preset["nn"]["epoch"],
                                device=device)
        var_time_1 = time.time()

        # Test
        model_lstm.load_state_dict(var_best_weight)
        # inference must not depend on whichever mode train() left the model in:
        # switch batch-norm (and any dropout) layers to evaluation behaviour explicitly
        model_lstm.eval()
        with torch.no_grad():
            predict_test_y = model_lstm(torch.from_numpy(data_test_x).to(device))

        predict_test_y = (torch.sigmoid(predict_test_y) > preset["nn"]["threshold"]).float()
        predict_test_y = predict_test_y.detach().cpu().numpy()
        var_time_2 = time.time()

        # Evaluate
        data_test_y_c = data_test_y.reshape(-1, data_test_y.shape[-1])
        predict_test_y_c = predict_test_y.reshape(-1, data_test_y.shape[-1])
        result_acc = accuracy_score(data_test_y_c.astype(int), predict_test_y_c.astype(int))
        result_dict = classification_report(data_test_y_c, predict_test_y_c, digits=6, zero_division=0, output_dict=True)
        result["repeat_" + str(var_r)] = result_dict
        result_accuracy.append(result_acc)
        result_time_train.append(var_time_1 - var_time_0)
        result_time_test.append(var_time_2 - var_time_1)
        print("repeat_" + str(var_r), result_accuracy)
        print(result)

    result["accuracy"] = {"avg": np.mean(result_accuracy), "std": np.std(result_accuracy)}
    result["time_train"] = {"avg": np.mean(result_time_train), "std": np.std(result_time_train)}
    result["time_test"] = {"avg": np.mean(result_time_test), "std": np.std(result_time_test)}
    # result["complexity"] = {"parameter": var_params, "flops": var_macs * 2}

    return result


---
Cell 7: ResNet18 Model
---
---

In [None]:
import os
os.environ['PYTORCH_CUDA_ALLOC_CONF'] = 'expandable_segments:True'

import torch._dynamo
torch._dynamo.config.suppress_errors = True
import time
import torch
torch.cuda.empty_cache()
import numpy as np
from torch.utils.data import TensorDataset, DataLoader
from sklearn.metrics import accuracy_score, classification_report
import torchvision.models as models
from copy import deepcopy

torch.set_float32_matmul_precision("high")
torch._dynamo.config.cache_size_limit = 65536

# We assume preset has already been defined (see Cell 2).
# preset = { "nn": {"lr": 1e-3, "epoch": 10, "batch_size": 4, "threshold": 0.5}, ... }

class ResNet18Model(torch.nn.Module):
    """
    ResNet-18 (no pretrained weights) with a single-channel first convolution and
    an output layer sized to var_y_shape[-1].

    NOTE(review): var_x_shape is accepted but not used; forward() reshapes every
    batch to the fixed size (batch, 1, 3000, 270).
    """

    def __init__(self, var_x_shape, var_y_shape):
        super().__init__()
        var_backbone = models.resnet18(weights=None)
        # replace the 3-channel stem by a 1-channel one (kernel 7, stride 3, padding 2)
        var_backbone.conv1 = torch.nn.Conv2d(in_channels=1, out_channels=64, kernel_size=7,
                                             stride=3, padding=2, bias=False)
        # new classification head; in_features is usually 512 for ResNet-18
        var_backbone.fc = torch.nn.Linear(var_backbone.fc.in_features, var_y_shape[-1])
        self.resnet = var_backbone

    def forward(self, var_input):
        var_batch = var_input.size(0)
        var_image = var_input.reshape(var_batch, 1, 3000, 270)
        return self.resnet(var_image)

def _resnet_evaluate(model, data_loader, device, var_threshold, loss_func=None):
    """
    Run `model` over `data_loader` one batch at a time (inputs stay on the CPU
    until their batch is processed).

    Returns a tuple `(predictions, labels, mean_loss)`:
    - predictions: thresholded sigmoid outputs stacked into one numpy array,
    - labels: the loader's labels stacked into one numpy array,
    - mean_loss: sample-weighted mean of `loss_func` over all batches, or
      None when `loss_func` is None or the loader yields no samples.
    """
    model.eval()
    preds_per_batch = []
    labels_per_batch = []
    loss_sum = 0.0
    num_samples = 0
    with torch.no_grad():
        for batch_x, batch_y in data_loader:
            logits = model(batch_x.to(device))
            if loss_func is not None:
                target = batch_y.to(device).reshape(batch_y.shape[0], -1).float()
                # Weight by batch size so a smaller last batch does not skew the mean.
                loss_sum += loss_func(logits, target).item() * batch_y.shape[0]
                num_samples += batch_y.shape[0]
            preds_per_batch.append((torch.sigmoid(logits) > var_threshold).float().cpu().numpy())
            labels_per_batch.append(batch_y.cpu().numpy())
    mean_loss = loss_sum / num_samples if num_samples else None
    return np.vstack(preds_per_batch), np.vstack(labels_per_batch), mean_loss


def _resnet_fit(model, optimizer, loss_func, data_train_set, data_test_set, device):
    """
    Train `model` for preset["nn"]["epoch"] epochs and return a copy of the
    state dict that reached the highest test accuracy.

    The initial weights are used as the fallback "best" weights, so the
    returned value is never None (even with zero epochs or zero accuracy).
    """
    var_threshold = preset["nn"]["threshold"]
    var_batch_size = preset["nn"]["batch_size"]
    var_epochs = preset["nn"]["epoch"]
    # Same label width as the one used for the final metric in run_resnet.
    var_label_dim = data_test_set.tensors[1].shape[-1]

    train_loader = DataLoader(data_train_set, var_batch_size, shuffle=True, pin_memory=False)
    test_loader = DataLoader(data_test_set, var_batch_size, shuffle=False, pin_memory=False)

    best_accuracy = -1.0
    best_weight = deepcopy(model.state_dict())

    for epoch in range(var_epochs):
        t0 = time.time()
        model.train()
        # Only the last training batch is reported, so keep references and
        # synchronise with the device once per epoch instead of once per batch.
        last_outputs = None
        last_targets = None
        last_loss = None
        for batch_x, batch_y in train_loader:
            batch_x = batch_x.to(device)
            batch_y = batch_y.to(device).reshape(batch_y.shape[0], -1).float()
            outputs = model(batch_x)
            loss_val = loss_func(outputs, batch_y)
            optimizer.zero_grad()
            loss_val.backward()
            optimizer.step()
            last_outputs = outputs.detach()
            last_targets = batch_y
            last_loss = loss_val.detach()

        if last_outputs is None:
            train_loss, train_acc = 0.0, 0.0
        else:
            train_loss = last_loss.item()
            train_preds = (torch.sigmoid(last_outputs) > var_threshold).float()
            train_acc = accuracy_score(last_targets.cpu().numpy().astype(int),
                                       train_preds.cpu().numpy().astype(int))

        # Batch-wise evaluation on the test set.
        preds_cat, labels_cat, test_loss = _resnet_evaluate(
            model, test_loader, device, var_threshold, loss_func)
        test_acc = accuracy_score(labels_cat.reshape(-1, var_label_dim).astype(int),
                                  preds_cat.reshape(-1, var_label_dim).astype(int))
        epoch_time = time.time() - t0
        print(f"Epoch {epoch}/{var_epochs} - "
              f"Train Loss: {train_loss:.6f}, "
              f"Train Acc: {train_acc:.6f}, "
              f"Test Loss: {(test_loss if test_loss is not None else 0.0):.6f}, "
              f"Test Acc: {test_acc:.6f} - "
              f"Time: {epoch_time:.4f}s")

        if test_acc > best_accuracy:
            best_accuracy = test_acc
            print('-----***-----')
            print(best_accuracy)
            best_weight = deepcopy(model.state_dict())
    return best_weight


def run_resnet(data_train_x, data_train_y, data_test_x, data_test_y, var_repeat=10):
    """
    Run WiFi-based model ResNet18.

    Args:
        data_train_x, data_test_x: numpy arrays of CSI samples; everything
            after the second axis is flattened into one feature axis.
        data_train_y, data_test_y: numpy arrays of encoded labels.
        var_repeat: number of repeated experiments (seed is `var_r + 39`).

    Returns:
        dict with one "repeat_<i>" entry per repeat ({"accuracy": ...}) plus
        "accuracy", "time_train" and "time_test" entries holding avg/std.

    Side effects:
        Writes f"{name_run}_model_final.pt" and f"{name_run}_best_model.pt"
        (overwritten on every repeat). Reads the module-level `preset` and
        `name_run`.
    """
    device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
    var_x_shape = data_train_x[0].shape
    var_y_shape = data_train_y[0].reshape(-1).shape
    var_threshold = preset["nn"]["threshold"]

    # Reshape on the CPU to (samples, 1, time, features).
    data_train_x = data_train_x.reshape(data_train_x.shape[0], 1, data_train_x.shape[1], -1)
    data_test_x = data_test_x.reshape(data_test_x.shape[0], 1, data_test_x.shape[1], -1)

    # Datasets stay on the CPU; only the current batch is moved to the device.
    data_train_set = TensorDataset(torch.from_numpy(data_train_x).float(),
                                   torch.from_numpy(data_train_y).float())
    data_test_set = TensorDataset(torch.from_numpy(data_test_x).float(),
                                  torch.from_numpy(data_test_y).float())

    result = {}
    result_accuracy = []
    result_time_train = []
    result_time_test = []

    for var_r in range(var_repeat):
        print("Repeat", var_r)
        torch.random.manual_seed(var_r + 39)

        # Build the model and move it to the device.
        model_resnet = ResNet18Model(var_x_shape, var_y_shape).to(device)
        optimizer = torch.optim.Adam(model_resnet.parameters(), lr=preset["nn"]["lr"], weight_decay=0)
        loss_func = torch.nn.BCEWithLogitsLoss(pos_weight=torch.tensor([6] * var_y_shape[-1]).to(device))

        t0_run = time.time()
        best_weight = _resnet_fit(model_resnet, optimizer, loss_func,
                                  data_train_set, data_test_set, device)
        t1_run = time.time()

        torch.save(model_resnet.state_dict(), f"{name_run}_model_final.pt")
        model_resnet.load_state_dict(best_weight)
        torch.save(model_resnet.state_dict(), f"{name_run}_best_model.pt")

        # Final evaluation of the best weights on the test set (small batches).
        test_loader_final = DataLoader(data_test_set, preset["nn"]["batch_size"],
                                       shuffle=False, pin_memory=False)
        preds_final, _, _ = _resnet_evaluate(model_resnet, test_loader_final, device, var_threshold)
        t2_run = time.time()

        data_test_y_np = data_test_y.reshape(-1, data_test_y.shape[-1])
        preds_final = preds_final.reshape(-1, data_test_y.shape[-1])
        acc_final = accuracy_score(data_test_y_np.astype(int), preds_final.astype(int))
        result[f"repeat_{var_r}"] = {"accuracy": acc_final}
        result_accuracy.append(acc_final)
        result_time_train.append(t1_run - t0_run)
        result_time_test.append(t2_run - t1_run)
        print("Repeat", var_r, "Final Test Accuracy:", acc_final)

    result["accuracy"] = {"avg": np.mean(result_accuracy), "std": np.std(result_accuracy)}
    result["time_train"] = {"avg": np.mean(result_time_train), "std": np.std(result_time_train)}
    result["time_test"] = {"avg": np.mean(result_time_test), "std": np.std(result_time_test)}
    return result


---
Cell8: for bi-LSTM Model
---
---

In [None]:
import torch._dynamo
# NOTE(review): presumably makes TorchDynamo fall back to eager execution
# instead of raising on compile errors - confirm against the PyTorch docs.
torch._dynamo.config.suppress_errors = True
import time
import torch
# Release cached CUDA memory left over from previously executed cells.
torch.cuda.empty_cache()
import numpy as np
from torch.utils.data import TensorDataset
from sklearn.metrics import classification_report, accuracy_score

class BiLSTMM(torch.nn.Module):
    """
    Bidirectional LSTM classifier for CSI sequences.

    Pipeline: BatchNorm1d over the feature axis -> average pooling along time
    (kernel 10, stride 10) -> bidirectional LSTM (hidden size 512 per
    direction) -> linear head applied to the LSTM output at the last time step.
    The input is indexed as (batch, time, features), with
    features == var_x_shape[-1]; the head has var_y_shape[-1] outputs.
    """

    def __init__(self, var_x_shape, var_y_shape):
        super().__init__()
        num_features, num_outputs = var_x_shape[-1], var_y_shape[-1]
        hidden_size = 512
        self.layer_norm = torch.nn.BatchNorm1d(num_features)
        self.layer_pooling = torch.nn.AvgPool1d(10, 10)
        # Bidirectional LSTM: its output feature size is 2 * hidden_size.
        self.layer_lstm = torch.nn.LSTM(input_size=num_features,
                                        hidden_size=hidden_size,
                                        batch_first=True,
                                        bidirectional=True)
        self.layer_linear = torch.nn.Linear(hidden_size * 2, num_outputs)

    def forward(self, var_input):
        """Return raw logits of shape (batch, var_y_shape[-1])."""
        # BatchNorm1d / AvgPool1d operate on (batch, features, time).
        features_first = torch.permute(var_input, (0, 2, 1))
        pooled = self.layer_pooling(self.layer_norm(features_first))
        # Back to (batch, time, features) for the batch_first LSTM.
        sequence_out, _ = self.layer_lstm(torch.permute(pooled, (0, 2, 1)))
        return self.layer_linear(sequence_out[:, -1, :])

def run_bilstm(data_train_x, data_train_y, data_test_x, data_test_y, var_repeat=10):
    """
    Run WiFi-based model bi-LSTM.

    Args:
        data_train_x, data_test_x: numpy arrays of CSI samples; everything
            after the second axis is flattened into one feature axis.
        data_train_y, data_test_y: numpy arrays of encoded labels.
        var_repeat: number of repeated experiments (seed is `var_r + 39`).

    Returns:
        dict with one "repeat_<i>" classification report per repeat plus
        "accuracy", "time_train" and "time_test" entries holding avg/std.

    Reads the module-level `preset` and delegates training to `train`.
    """
    device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
    var_threshold = preset["nn"]["threshold"]
    var_batch_size = preset["nn"]["batch_size"]

    data_train_x = data_train_x.reshape(data_train_x.shape[0], data_train_x.shape[1], -1)
    data_test_x = data_test_x.reshape(data_test_x.shape[0], data_test_x.shape[1], -1)

    var_x_shape, var_y_shape = data_train_x[0].shape, data_train_y[0].reshape(-1).shape

    # .float() is a no-op for float32 input and keeps the model usable when the
    # loader hands over another floating dtype (same cast as in run_resnet).
    data_train_set = TensorDataset(torch.from_numpy(data_train_x).float(), torch.from_numpy(data_train_y))
    data_test_set = TensorDataset(torch.from_numpy(data_test_x).float(), torch.from_numpy(data_test_y))

    result = {}
    result_accuracy = []
    result_time_train = []
    result_time_test = []

    for var_r in range(var_repeat):
        print("Repeat", var_r)
        torch.random.manual_seed(var_r + 39)
        model_bilstm = BiLSTMM(var_x_shape, var_y_shape).to(device)
        optimizer = torch.optim.Adam(model_bilstm.parameters(),
                                     lr=preset["nn"]["lr"],
                                     weight_decay=0)
        loss = torch.nn.BCEWithLogitsLoss(pos_weight=torch.tensor([6] * var_y_shape[-1]).to(device))
        var_time_0 = time.time()

        var_best_weight = train(model=model_bilstm,
                                optimizer=optimizer,
                                loss=loss,
                                data_train_set=data_train_set,
                                data_test_set=data_test_set,
                                var_threshold=var_threshold,
                                var_batch_size=var_batch_size,
                                var_epochs=preset["nn"]["epoch"],
                                device=device)
        var_time_1 = time.time()

        model_bilstm.load_state_dict(var_best_weight)
        # Explicit eval mode: BatchNorm must use its running statistics here,
        # independent of the mode `train` happened to leave the model in.
        model_bilstm.eval()

        # Predict in mini-batches so the whole test set never has to sit on
        # the device at once.
        predict_chunks = []
        with torch.no_grad():
            for var_start in range(0, data_test_x.shape[0], var_batch_size):
                data_chunk = torch.from_numpy(data_test_x[var_start:var_start + var_batch_size]).float()
                predict_chunks.append(model_bilstm(data_chunk.to(device)).cpu())
        predict_test_y = torch.cat(predict_chunks, dim=0)

        predict_test_y = (torch.sigmoid(predict_test_y) > var_threshold).float()
        predict_test_y = predict_test_y.numpy()
        var_time_2 = time.time()

        data_test_y_c = data_test_y.reshape(-1, data_test_y.shape[-1])
        predict_test_y_c = predict_test_y.reshape(-1, data_test_y.shape[-1])
        result_acc = accuracy_score(data_test_y_c.astype(int), predict_test_y_c.astype(int))
        result_dict = classification_report(data_test_y_c, predict_test_y_c, digits=6, zero_division=0, output_dict=True)
        result["repeat_" + str(var_r)] = result_dict
        result_accuracy.append(result_acc)
        result_time_train.append(var_time_1 - var_time_0)
        result_time_test.append(var_time_2 - var_time_1)
        print("repeat_" + str(var_r), result_accuracy)
        print(result)

    result["accuracy"] = {"avg": np.mean(result_accuracy), "std": np.std(result_accuracy)}
    result["time_train"] = {"avg": np.mean(result_time_train), "std": np.std(result_time_train)}
    result["time_test"] = {"avg": np.mean(result_time_test), "std": np.std(result_time_test)}

    return result


---
Cell 9: train.py
---
---

In [None]:
"""
[file]          train.py
[description]   function to train WiFi-based models
"""

# All necessary libraries are imported in Cell 1.

# Global PyTorch runtime settings used while training.
# NOTE(review): "high" presumably allows faster, reduced-precision float32
# matmuls - confirm against the torch.set_float32_matmul_precision docs.
torch.set_float32_matmul_precision("high")
# Raise TorchDynamo's cache size limit.
torch._dynamo.config.cache_size_limit = 65536

def _evaluate_in_batches(model, loss, data_loader, var_threshold, device):
    """
    Evaluate `model` on `data_loader` one mini-batch at a time.

    Returns `(mean_loss, accuracy)` where mean_loss is the sample-weighted
    mean of `loss` over all batches and accuracy is computed on thresholded
    sigmoid outputs reshaped to rows of the labels' last dimension.
    """
    model.eval()
    var_loss_sum = 0.0
    var_num_samples = 0
    predict_chunks = []
    label_chunks = []
    with torch.no_grad():
        for data_x, data_y in data_loader:
            # Only the current batch is moved to the device.
            data_x = data_x.to(device)
            data_y = data_y.to(device)
            predict_y = model(data_x)
            var_loss = loss(predict_y, data_y.reshape(data_y.shape[0], -1).float())
            # Weight by batch size so a smaller last batch does not skew the mean.
            var_loss_sum += var_loss.item() * data_y.shape[0]
            var_num_samples += data_y.shape[0]
            predict_chunks.append((torch.sigmoid(predict_y) > var_threshold).float().cpu().numpy())
            label_chunks.append(data_y.cpu().numpy())

    data_y_all = np.concatenate(label_chunks, axis=0)
    predict_y_all = np.concatenate(predict_chunks, axis=0)
    predict_y_all = predict_y_all.reshape(-1, data_y_all.shape[-1])
    data_y_all = data_y_all.reshape(-1, data_y_all.shape[-1])
    var_accuracy = accuracy_score(data_y_all.astype(int), predict_y_all.astype(int))
    return var_loss_sum / var_num_samples, var_accuracy


def train(model, optimizer, loss, data_train_set, data_test_set, var_threshold, var_batch_size, var_epochs, device):
    """
    Generic training function for WiFi-based models.

    Args:
        model: torch module producing logits.
        optimizer: optimizer bound to `model`'s parameters.
        loss: criterion called as `loss(logits, flattened_float_labels)`.
        data_train_set, data_test_set: datasets yielding `(x, y)` pairs.
        var_threshold: threshold applied to sigmoid outputs.
        var_batch_size: mini-batch size for training and evaluation.
        var_epochs: number of epochs.
        device: device used for the forward/backward passes.

    Returns:
        A copy of the state dict with the highest test accuracy (the initial
        weights if no epoch is run).

    Raises:
        ValueError: if `data_test_set` is empty.

    Side effects:
        Writes f"{name_run}_model_final.pt" and f"{name_run}_best_model.pt"
        (`name_run` is a module-level name).
    """
    if len(data_test_set) == 0:
        raise ValueError("data_test_set is empty; there is nothing to evaluate on.")

    # Data stays on the CPU (pin_memory=False); batches are moved on demand.
    data_train_loader = DataLoader(data_train_set, var_batch_size, shuffle=True, pin_memory=False)
    # Evaluate in mini-batches so the whole test set never has to fit on the device.
    data_test_loader = DataLoader(data_test_set, var_batch_size, shuffle=False, pin_memory=False)

    var_best_accuracy = -1.0
    var_best_weight = deepcopy(model.state_dict())

    for var_epoch in range(var_epochs):
        var_time_e0 = time.time()
        model.train()
        var_loss_train = None
        for data_batch_x, data_batch_y in data_train_loader:
            # Move the batch to the device only for this step.
            data_batch_x = data_batch_x.to(device)
            data_batch_y = data_batch_y.to(device)
            predict_train_y = model(data_batch_x)
            var_loss_train = loss(predict_train_y, data_batch_y.reshape(data_batch_y.shape[0], -1).float())
            optimizer.zero_grad()
            var_loss_train.backward()
            optimizer.step()

        if var_loss_train is None:
            # The training loader produced no batch; nothing to report.
            var_loss_train_value = float("nan")
            var_accuracy_train = 0.0
        else:
            # Training accuracy is reported on the last batch only.
            var_loss_train_value = var_loss_train.item()
            predict_last = (torch.sigmoid(predict_train_y) > var_threshold).float().detach().cpu().numpy()
            data_last = data_batch_y.detach().cpu().numpy()
            predict_last = predict_last.reshape(-1, data_last.shape[-1])
            data_last = data_last.reshape(-1, data_last.shape[-1])
            var_accuracy_train = accuracy_score(data_last.astype(int), predict_last.astype(int))

        var_loss_test, var_accuracy_test = _evaluate_in_batches(
            model, loss, data_test_loader, var_threshold, device)

        print(f"Epoch {var_epoch}/{var_epochs}",
              "- %.6fs" % (time.time() - var_time_e0),
              "- Loss %.6f" % var_loss_train_value,
              "- Accuracy %.6f" % var_accuracy_train,
              "- Test Loss %.6f" % var_loss_test,
              "- Test Accuracy %.6f" % var_accuracy_test)

        if var_accuracy_test > var_best_accuracy:
            var_best_accuracy = var_accuracy_test
            print('-----***-----')
            print(var_best_accuracy)
            var_best_weight = deepcopy(model.state_dict())

    torch.save(model.state_dict(), f"{name_run}_model_final.pt")
    torch.save(var_best_weight, f"{name_run}_best_model.pt")

    return var_best_weight



# === Add the required imports once at the top of the file ===
from sklearn.metrics import confusion_matrix, ConfusionMatrixDisplay
from matplotlib.backends.backend_pdf import PdfPages
import matplotlib.pyplot as plt

# ---------- helper function ----------
def save_confusion_matrix(model, data_loader, threshold, device, pdf_path):
    """
    Evaluate `model` on `data_loader` and store a confusion matrix as a PDF.

    Sigmoid outputs are binarised with `threshold`; predictions and labels of
    every batch are flattened, so the matrix is built over individual output
    entries rather than over samples. The figure is written to `pdf_path`.
    """
    model.eval()
    truth_flat = []
    predicted_flat = []

    with torch.no_grad():
        for batch_x, batch_y in data_loader:
            scores = torch.sigmoid(model(batch_x.to(device)))
            binary = (scores > threshold).float()
            truth_flat.extend(batch_y.cpu().numpy().ravel())
            predicted_flat.extend(binary.cpu().numpy().ravel())

    matrix = confusion_matrix(truth_flat, predicted_flat)
    figure, axes = plt.subplots()
    display = ConfusionMatrixDisplay(matrix)
    display.plot(ax=axes)
    axes.set_title("Confusion Matrix – Test")

    with PdfPages(pdf_path) as pdf_file:
        pdf_file.savefig(figure)
    plt.close(figure)
# ---------------------------------


---
Cell 10: run.py
---
---

In [None]:
# """
# [file]          run.py
# [description]   run WiFi-based models
# """

# # All necessary libraries are imported in Cell 1.
# # from model import *   --> Make sure the proper model functions (e.g., run_that) are defined in the respective cells.
# # from preset import preset --> Already defined in Cell 2.
# # from load_data import load_data_x, load_data_y, encode_data_y --> Defined in Cell 3.

# def parse_args():
#     var_args = argparse.ArgumentParser()
#     var_args.add_argument("--model", default=preset["model"], type=str)
#     var_args.add_argument("--task", default=preset["task"], type=str)
#     var_args.add_argument("--repeat", default=preset["repeat"], type=int)
#     args, _ = var_args.parse_known_args()  # پارامترهای اضافی نادیده گرفته می‌شن
#     return args

# def run():
#     """
#     Run WiFi-based models.
#     """
#     var_args = parse_args()
#     var_task = var_args.task
#     var_model = var_args.model
#     var_repeat = var_args.repeat
    
#     data_pd_y = load_data_y(preset["path"]["data_y"],
#                             var_environment=preset["data"]["environment"], 
#                             var_wifi_band=preset["data"]["wifi_band"], 
#                             var_num_users=preset["data"]["num_users"])
#     var_label_list = data_pd_y["label"].to_list()
#     data_x = load_data_x(preset["path"]["data_x"], var_label_list)
#     data_y = encode_data_y(data_pd_y, var_task)
#     print("========================1111111==========")
#     print(data_y[0])
#     print(data_y[5])
#     print(data_y[10])
#     print(data_y.shape)
#     print("==========================11111111========")
#     data_train_x, data_test_x, data_train_y, data_test_y = train_test_split(data_x, data_y, 
#                                                                             test_size=0.2, 
#                                                                             shuffle=True, 
#                                                                             random_state=39)
#     print('-------------------------------------00000000-----')
#     print(data_test_y[5])
#     print(data_test_y[10])
#     print(data_test_y[15])
#     print('-------------------------------------00000000-----')
#     if var_model == "ST-RF":
#         run_model = run_strf
#     elif var_model == "LSTM":
#         run_model = run_lstm
#     elif var_model == "bi-LSTM":
#         run_model = run_bilstm
#     elif var_model == "THAT":
#         run_model = run_that
#     elif var_model == "ResNet18":
#         run_model = run_resnet
        
#     print(var_model)
#     result = run_model(data_train_x, data_train_y, data_test_x, data_test_y, var_repeat)
#     result["model"] = var_model
#     result["task"] = var_task
#     result["data"] = preset["data"]
#     result["nn"] = preset["nn"]
#     print(result)
    
#     with open(preset["path"]["save"], 'w') as var_file:
#         json.dump(result, var_file, indent=4)

# if __name__ == "__main__":
#     print('start')
#     run()


In [None]:
import gc
import torch
# Collect unreferenced Python objects first, then ask PyTorch to release its
# cached CUDA memory before the run below starts.
gc.collect()
torch.cuda.empty_cache()
torch.cuda.ipc_collect()




"""
[file]          run.py
[description]   run WiFi-based models and optionally save a multiclass confusion matrix
"""

import argparse
import json
import torch
from torch.utils.data import DataLoader, TensorDataset
from sklearn.model_selection import train_test_split
import matplotlib.pyplot as plt
from sklearn.metrics import confusion_matrix, ConfusionMatrixDisplay
from matplotlib.backends.backend_pdf import PdfPages

# from preset import preset, name_run
# from load_data import load_data_x, load_data_y, encode_data_y
# from lstm import run_lstm, LSTMM
# from bilstm import run_bilstm, BiLSTMM
# from that import run_that, THAT
# from resnet import run_resnet, ResNet18Model
# from strf import run_strf  # if you have the ST-RF implementation

def parse_args():
    """
    Parse the command-line options of the run script.

    `--model`, `--task` and `--repeat` default to the matching `preset`
    entries; `--save_cm` is a boolean switch. Unknown arguments (such as the
    ones a notebook kernel injects) are ignored.
    """
    cli = argparse.ArgumentParser()
    for option, value_type in (("model", str), ("task", str), ("repeat", int)):
        cli.add_argument("--" + option, default=preset[option], type=value_type)
    cli.add_argument(
        "--save_cm",
        action="store_true",
        help="Save a multiclass confusion matrix of the best model to PDF",
    )
    known_args, _ignored = cli.parse_known_args()
    return known_args

def save_multiclass_confusion_matrix(model, data_loader, device, pdf_path, num_classes):
    """
    Plot a num_classes x num_classes confusion matrix and save it to `pdf_path`.

    For every batch the predicted class is the argmax of the model's logits
    along dim 1 and the true class is the argmax of the label tensor along
    dim 1. Inputs are moved to `device` before the forward pass.
    """
    model.eval()
    true_classes = []
    predicted_classes = []
    with torch.no_grad():
        for batch_x, batch_y in data_loader:
            batch_logits = model(batch_x.to(device))
            # The index of the largest logit is the predicted class.
            predicted_classes += torch.argmax(batch_logits, dim=1).cpu().numpy().tolist()
            true_classes += torch.argmax(batch_y, dim=1).cpu().numpy().tolist()

    class_ids = list(range(num_classes))
    matrix = confusion_matrix(true_classes, predicted_classes, labels=class_ids)
    figure, axes = plt.subplots(figsize=(8, 8))
    ConfusionMatrixDisplay(matrix, display_labels=class_ids).plot(ax=axes, xticks_rotation="vertical")
    axes.set_title("Confusion Matrix")
    with PdfPages(pdf_path) as pdf_file:
        pdf_file.savefig(figure)
    plt.close(figure)

def run():
    """
    Train and evaluate the selected WiFi-based model and persist the results.

    Steps: parse the CLI options, load and encode the data selected in
    `preset`, split it 80/20 (random_state=39), run the matching `run_*`
    function, dump the result dict to preset["path"]["save"] and, when
    requested, save a multiclass confusion matrix of the best checkpoint.

    The confusion matrix is produced when `--save_cm` is passed or when the
    module-level `Confusion_matrix` flag equals 1.

    Raises:
        ValueError: for an unknown model name, or when a confusion matrix is
            requested for a model without a torch architecture (ST-RF).
    """
    args       = parse_args()
    var_model  = args.model
    var_task   = args.task
    var_repeat = args.repeat

    # --- Load and encode the data ---
    data_pd_y = load_data_y(
        preset["path"]["data_y"],
        var_environment=preset["data"]["environment"],
        var_wifi_band=preset["data"]["wifi_band"],
        var_num_users=preset["data"]["num_users"]
    )
    labels = data_pd_y["label"].tolist()
    data_x = load_data_x(preset["path"]["data_x"], labels)
    data_y = encode_data_y(data_pd_y, var_task)

    train_x, test_x, train_y, test_y = train_test_split(
        data_x, data_y, test_size=0.2, shuffle=True, random_state=39
    )

    # --- Select which model runner to use ---
    # An if/elif chain (not a dict) so that only the chosen runner's name is
    # looked up; cells of unused models need not have been executed.
    if var_model == "ST-RF":
        from strf import run_strf
        run_model = run_strf
    elif var_model == "LSTM":
        run_model = run_lstm
    elif var_model == "bi-LSTM":
        run_model = run_bilstm
    elif var_model == "THAT":
        run_model = run_that
    elif var_model == "ResNet18":
        run_model = run_resnet
    else:
        raise ValueError(f"Unknown model: {var_model}")

    # --- Train and evaluate ---
    print(f"Running model: {var_model}")
    result = run_model(train_x, train_y, test_x, test_y, var_repeat)
    result["model"] = var_model
    result["task"]  = var_task
    result["data"]  = preset["data"]
    result["nn"]    = preset["nn"]
    print(result)

    # --- Save results to JSON ---
    with open(preset["path"]["save"], "w") as f:
        json.dump(result, f, indent=4)

    # --- Optionally save a multiclass confusion matrix ---
    if not (args.save_cm or Confusion_matrix == 1):
        return

    # 1) release memory left over from training
    import gc
    gc.collect()
    if torch.cuda.is_available():
        torch.cuda.empty_cache()
        torch.cuda.ipc_collect()

    # 2) reshape the input only if the network is sequence-based
    if var_model in ("LSTM", "bi-LSTM", "THAT"):
        test_x_cm = test_x.reshape(test_x.shape[0], test_x.shape[1], -1)
    else:                           # ResNet18, ST-RF
        test_x_cm = test_x

    # 3) build the *same* architecture on CPU and load its weights.
    #    Flattened label shape, as in run_resnet/run_bilstm, so the output
    #    layer matches the checkpoint.
    device_cm = torch.device("cpu")
    var_x_shape = test_x_cm[0].shape
    var_y_shape = test_y[0].reshape(-1).shape
    if var_model == "LSTM":
        model_cm = LSTMM(var_x_shape, var_y_shape).to(device_cm)
    elif var_model == "bi-LSTM":
        model_cm = BiLSTMM(var_x_shape, var_y_shape).to(device_cm)
    elif var_model == "THAT":
        model_cm = THAT(var_x_shape, var_y_shape).to(device_cm)
    elif var_model == "ResNet18":
        model_cm = ResNet18Model(var_x_shape, var_y_shape).to(device_cm)
    else:
        raise ValueError(f"Confusion matrix not supported for {var_model}")

    # Same relative path the training code writes the checkpoint to.
    best_path = f"{name_run}_best_model.pt"
    model_cm.load_state_dict(torch.load(best_path, map_location=device_cm))
    model_cm.eval()

    # 4) DataLoader on CPU with a safe batch size
    test_ds = TensorDataset(torch.from_numpy(test_x_cm).float(),
                            torch.from_numpy(test_y).float())
    test_loader = DataLoader(test_ds, batch_size=64, shuffle=False)

    # 5) save the confusion matrix PDF
    num_classes = test_y.shape[1]
    pdf_name = f"{name_run}_confusion_matrix.pdf"
    save_multiclass_confusion_matrix(model_cm, test_loader, device_cm, pdf_name, num_classes)
    print(f"✅ Saved confusion matrix (classes 0–{num_classes-1}) to {pdf_name}")
# Entry point: start the pipeline when this module is executed as __main__.
if __name__ == "__main__":
    print("start")
    run()


In [None]:
import gc
import torch
import shutil
import json
from sklearn.model_selection import train_test_split
from torch.utils.data import DataLoader, TensorDataset
import matplotlib.pyplot as plt
from sklearn.metrics import confusion_matrix, ConfusionMatrixDisplay
from matplotlib.backends.backend_pdf import PdfPages

# Collect unreferenced Python objects first, then ask PyTorch to release its
# cached CUDA memory before the few-shot run below starts.
gc.collect()
torch.cuda.empty_cache()
torch.cuda.ipc_collect()

# ---------- helper: save multiclass confusion matrix ------------------
def save_multiclass_confusion_matrix(model, data_loader, pdf_path, num_classes):
    """
    Run `model` on CPU tensors from `data_loader` and write a
    num_classes x num_classes confusion matrix to the one-page PDF `pdf_path`.

    Predicted and true classes are the argmax along dim 1 of the logits and
    of the label tensor, respectively.
    """
    model.eval()
    true_classes = []
    predicted_classes = []
    with torch.no_grad():
        for batch_x, batch_y in data_loader:
            batch_logits = model(batch_x.cpu())            # keep the forward pass on CPU
            predicted_classes += torch.argmax(batch_logits, dim=1).numpy().tolist()
            true_classes += torch.argmax(batch_y, dim=1).numpy().tolist()

    class_ids = list(range(num_classes))
    matrix = confusion_matrix(true_classes, predicted_classes, labels=class_ids)
    figure, axes = plt.subplots(figsize=(8, 8))
    ConfusionMatrixDisplay(matrix, display_labels=class_ids).plot(ax=axes, xticks_rotation="vertical")
    axes.set_title("Few‑shot Confusion Matrix")
    with PdfPages(pdf_path) as pdf_file:
        pdf_file.savefig(figure)
    plt.close(figure)

# -------------------- pick run_* function ------------------------------
# Few-shot experiment: train on a small subset of `dest_env` and evaluate on
# its held-out split. Relies on module-level configuration defined elsewhere
# (`preset`, `name_run`, `dest_env`, `few_shot_num_samples`, `few_shot_epochs`,
# `Confusion_matrix`).

# -------------------- pick run_* function ------------------------------
# if/elif chain so that only the selected runner's name has to exist.
if preset["model"] == "ST-RF":
    run_model = run_strf
elif preset["model"] == "LSTM":
    run_model = run_lstm
elif preset["model"] == "bi-LSTM":
    run_model = run_bilstm
elif preset["model"] == "THAT":
    run_model = run_that
elif preset["model"] == "ResNet18":
    run_model = run_resnet
else:
    raise ValueError(f"No few‑shot implementation for {preset['model']}.")

# ------------------------ load / split data ----------------------------
data_pd_y = load_data_y(preset["path"]["data_y"],
                        var_environment=[dest_env],
                        var_wifi_band=preset["data"]["wifi_band"],
                        var_num_users=preset["data"]["num_users"])

labels_list = data_pd_y["label"].tolist()
data_x = load_data_x(preset["path"]["data_x"], labels_list)
data_y = encode_data_y(data_pd_y, preset["task"])

train_x, test_x, train_y, test_y = train_test_split(
    data_x, data_y, test_size=0.2, shuffle=True, random_state=39)

# Few-shot sample size
train_x = train_x[:few_shot_num_samples]
train_y = train_y[:few_shot_num_samples]

# Sequence models are trained on inputs flattened to (samples, time, features),
# so their architecture must be rebuilt from that shape for the checkpoint to
# fit; labels are flattened the same way run_resnet/run_bilstm do.
test_x_rs = (test_x.reshape(test_x.shape[0], test_x.shape[1], -1)
             if preset["model"] in ("LSTM", "bi-LSTM", "THAT") else test_x)
sample_x_shape = test_x_rs[0].shape
sample_y_shape = test_y[0].reshape(-1).shape

# ------------------- check the pre-trained checkpoint -------------------
best_model_path = f"{name_run}_best_model.pt"

if preset["model"] == "LSTM":
    model = LSTMM(sample_x_shape, sample_y_shape)
elif preset["model"] == "bi-LSTM":
    model = BiLSTMM(sample_x_shape, sample_y_shape)
elif preset["model"] == "THAT":
    model = THAT(sample_x_shape, sample_y_shape)
elif preset["model"] == "ResNet18":
    model = ResNet18Model(sample_x_shape, sample_y_shape)
else:
    raise ValueError(f"Model {preset['model']} not supported!")

# Loading here fails early if the checkpoint is missing or does not match the
# architecture, before any training time is spent.
model.load_state_dict(torch.load(best_model_path, map_location="cpu"))

# ----------------------- few-shot training -----------------------------
# NOTE(review): the run_* functions visible in this file build and train their
# own fresh model; `model` (holding the pre-trained weights) is not passed to
# them, so this step trains from scratch on the few-shot subset rather than
# fine-tuning. TODO: let run_* accept initial weights to get real fine-tuning.
original_epochs = preset["nn"]["epoch"]
preset["nn"]["epoch"] = few_shot_epochs
try:
    result = run_model(train_x, train_y, test_x, test_y, var_repeat=1)
finally:
    # Always restore the shared preset, even if training raises.
    preset["nn"]["epoch"] = original_epochs
print(result)

# --------------------- save few-shot checkpoints -----------------------
# run_resnet and train() write the weights they just trained to
# f"{name_run}_model_final.pt" / f"{name_run}_best_model.pt" (overwriting the
# pre-trained files). Copy those, not `model`, whose weights were never updated.
# NOTE(review): run_lstm / run_that are assumed to save through train() as well
# - confirm.
shutil.copyfile(f"{name_run}_model_final.pt", f"{name_run}_fewshot_final_model.pt")
shutil.copyfile(best_model_path, f"{name_run}_fewshot_best_model.pt")

# ------------------- confusion matrix on CPU ---------------------------
if Confusion_matrix == 1 and preset["model"] != "ST-RF":

    # instantiate identical architecture on CPU
    if preset["model"] == "LSTM":
        model_cpu = LSTMM(sample_x_shape, sample_y_shape).cpu()
    elif preset["model"] == "bi-LSTM":
        model_cpu = BiLSTMM(sample_x_shape, sample_y_shape).cpu()
    elif preset["model"] == "THAT":
        model_cpu = THAT(sample_x_shape, sample_y_shape).cpu()
    else:  # ResNet18
        model_cpu = ResNet18Model(sample_x_shape, sample_y_shape).cpu()

    # load weights
    model_cpu.load_state_dict(torch.load(f"{name_run}_fewshot_best_model.pt", map_location="cpu"))

    # CPU DataLoader with a safe batch size
    test_ds = TensorDataset(torch.from_numpy(test_x_rs).float(),
                            torch.from_numpy(test_y).float())
    test_loader = DataLoader(test_ds, batch_size=64, shuffle=False)

    pdf_name = f"{name_run}_fewshot_confusion_matrix.pdf"
    num_classes = test_y.shape[1]
    save_multiclass_confusion_matrix(model_cpu, test_loader, pdf_name, num_classes)
    print(f"✅ Saved few‑shot confusion matrix (classes 0–{num_classes-1}) to {pdf_name}")

# ----------------------------- persist ---------------------------------
# Save the final result to JSON
with open("result_fewshot.json", "w") as f:
    json.dump(result, f, indent=4)
