<h1> Data Utils </h1>

In [1]:
import os
import math
import string
import pickle
import numpy as np
import scipy.io as sio
import torch
import torch.nn as nn
import torch.optim as optim
from torch.utils.data import Dataset, DataLoader
import wget
import zipfile
import json5  
from safetensors.torch import load_file, save_file 
from typing import Tuple, List, Dict
import pandas as pd
import import_ipynb
import random

from config_setup import TrainingConfig

<h3> Download and Prepare Data </h3>

In [None]:
def download_and_prepare_data(config: TrainingConfig) -> List[str]:
    """
    Make sure the dataset is on disk, then build the list of .mat files to use.

    If the sentinel file ``<dataset_name>a_A.mat`` is missing from
    ``config.dataset_path``, the archive ``<base_data_dir>/<dataset_name>.zip``
    is downloaded from ``config.raw_data_zip_url`` (unless the archive is
    already present) and extracted into ``config.dataset_path``.

    ``config.receivers`` and ``config.activities`` may each be a single id or a
    list/tuple of ids. For every (receiver, activity) pair, in receiver-major
    order, ``<dataset_path>/<dataset_name><receiver>_<activity>.mat`` is kept
    if it exists; missing files are reported and skipped.

    Returns:
        The paths of the existing .mat files, in the order described above.

    Raises:
        FileNotFoundError: if none of the requested files exist.
    """

    def as_list(value) -> list:
        # Lists and tuples are both treated as collections of ids; anything
        # else (e.g. a single string id) is wrapped as one id.
        return list(value) if isinstance(value, (list, tuple)) else [value]

    zip_path = os.path.join(config.base_data_dir, f"{config.dataset_name}.zip")

    # One known file is used as a marker that the archive was already extracted.
    example_mat_file = os.path.join(config.dataset_path, f"{config.dataset_name}a_A.mat")
    if not os.path.exists(example_mat_file):
        if not os.path.exists(zip_path):
            # The download target directory may not exist yet on a fresh machine.
            os.makedirs(os.path.dirname(zip_path) or ".", exist_ok=True)
            print(f"Downloading data from {config.raw_data_zip_url}...")
            wget.download(config.raw_data_zip_url, zip_path)
            print("Download complete.")

        print(f"Extracting data to {config.dataset_path}...")
        with zipfile.ZipFile(zip_path, 'r') as zip_ref:
            zip_ref.extractall(config.dataset_path)
        print("Extraction complete.")
        if not os.path.exists(example_mat_file):
            print(f"Warning: {example_mat_file} still not found after extraction; "
                  f"check the archive layout.")
    else:
        print(f"Data already found in {config.dataset_path}")

    receivers = as_list(config.receivers)
    activities = as_list(config.activities)

    print(f"Building file list for scenario {config.dataset_name}...")
    print(f"Receivers: {receivers}, Activities: {activities}")

    file_list = []
    for receiver_id in receivers:
        for activity_id in activities:
            file_name = f"{config.dataset_name}{receiver_id}_{activity_id}.mat"
            full_path = os.path.join(config.dataset_path, file_name)

            if os.path.exists(full_path):
                file_list.append(full_path)
            else:
                print(f"Warning: File not found and will be skipped: {full_path}")

    if not file_list:
        raise FileNotFoundError("ERROR: No files found for the specified configuration.")

    print(f"Generated file list for the dataset ({len(file_list)} files): {file_list}")

    return file_list


<h3> Prepare CSI Dataset </h3>

In [3]:
class CsiPyTorchDataset(Dataset):
    """Sliding-window CSI magnitude dataset built from a list of ``.mat`` files.

    Each file must hold a ``'csi'`` array with at least 3 dimensions, indexed
    here as ``data[sample, feature, antenna]``. For every file:

    * the first ``config.samples_per_file`` samples are kept;
    * only the antennas in ``config.antenna_indices`` are kept (a single int
      index or a sequence of indices);
    * values are converted to rounded magnitudes, ``round(abs(x))``, as float32.

    All files are concatenated along the sample axis, and the result is divided
    by its global maximum. Every run of ``config.window_size`` consecutive samples
    that lies entirely inside one file becomes one item. A file with fewer
    samples than ``window_size`` is skipped entirely.

    Items are ``(window, label)``:

    * ``window`` has shape ``(num_selected_antennas, window_size, features)``;
    * ``label`` is an int64 tensor holding the file's position in ``file_list``.
    """

    def __init__(self, config: TrainingConfig, file_list: List[str]):
        self.config = config
        self.window_size = config.window_size
        self.samples_per_file = config.samples_per_file
        self.antenna_indices = config.antenna_indices
        self.input_channels = config.input_channels

        self.all_csi_segments = []
        self.all_labels_for_windows = []
        self.all_start_indices_in_concatenated_csi = []

        # Normalize to a list of ints: indexing with a list (even of length 1)
        # keeps the antenna axis, so no dimension has to be restored later.
        antenna_indices = [int(i) for i in np.atleast_1d(self.antenna_indices)]
        if not antenna_indices:
            raise ValueError("antenna_indices must contain at least one index.")
        max_antenna_idx = max(antenna_indices)

        current_concat_offset = 0
        print("Loading MAT files...")
        # NOTE(review): the label is the file's position in file_list, not a
        # parsed activity id. If file_list spans several receivers, the same
        # activity gets a different label per receiver. A file that fails to
        # load also leaves a gap in the label range. Confirm this is intended.
        for activity_idx, file_path in enumerate(file_list):
            data = self._load_selected_csi(file_path, antenna_indices, max_antenna_idx)
            if data is None:
                continue

            num_possible_windows_this_file = data.shape[0] - self.window_size + 1
            if num_possible_windows_this_file <= 0:
                # Skip the file completely. Appending its samples without
                # advancing the offset would shift every later file's windows.
                print(f"Warning: samples_per_file ({data.shape[0]}) in {file_path} is less than window_size ({self.window_size}).")
                continue

            self.all_csi_segments.append(torch.from_numpy(data))
            # Window starts are stored as offsets into the concatenated tensor.
            self.all_start_indices_in_concatenated_csi.extend(
                range(current_concat_offset, current_concat_offset + num_possible_windows_this_file)
            )
            self.all_labels_for_windows.extend([activity_idx] * num_possible_windows_this_file)

            current_concat_offset += data.shape[0]
            print(f"Processed  {file_path}, {num_possible_windows_this_file} windows added.")

        if not self.all_csi_segments:
            raise RuntimeError("No CSI data loaded. Check paths and MAT files.")

        self.csi_data_concatenated = torch.cat(self.all_csi_segments, dim=0)

        # Scale all CSI data into [0, 1] by the global maximum.
        if self.csi_data_concatenated.numel() > 0:
            max_val = torch.max(self.csi_data_concatenated)
            if max_val > 0:
                self.csi_data_concatenated /= max_val

        print(f"Dataset initialized. CSI shape: {self.csi_data_concatenated.shape}")
        print(f"Total number of windows: {len(self)}")

    def _load_selected_csi(self, file_path: str, antenna_indices: List[int], max_antenna_idx: int):
        """Load one file's ``'csi'`` array and return the selected antennas as float32.

        Keeps at most ``self.samples_per_file`` samples and returns
        ``round(abs(x))`` of the selection.

        Returns:
            The selected array, or None (after printing a message) when the file
            cannot be loaded or has no ``'csi'`` entry.

        Raises:
            ValueError: if the array has fewer than 3 dimensions or lacks a
                requested antenna.
        """
        try:
            mat = sio.loadmat(file_path)
            data = np.array(mat['csi'])
        except Exception as e:
            # Best-effort: an unreadable file is reported and skipped.
            print(f"Error loading {file_path}: {e}. Skipped.")
            return None

        if data.ndim < 3 or data.shape[2] <= max_antenna_idx:
            raise ValueError(f"Data in {file_path} is not compatible. Required antenna index {max_antenna_idx}, "
                             f"but data shape is {data.shape}.")

        samples_to_take = min(self.samples_per_file, data.shape[0])
        selected_data = data[:samples_to_take, :, antenna_indices]
        return np.round(np.abs(selected_data)).astype(np.float32)

    def __len__(self):
        """Return the number of windows."""
        return len(self.all_start_indices_in_concatenated_csi)

    def __getitem__(self, idx: int) -> Tuple[torch.Tensor, torch.Tensor]:
        """Return window ``idx`` as ``(channels, window_size, features)`` and its label."""
        actual_start_idx = self.all_start_indices_in_concatenated_csi[idx]
        window_data = self.csi_data_concatenated[actual_start_idx : actual_start_idx + self.window_size, ...]

        # PyTorch Conv2d expects (Batch, Channels, Height, Width).
        # Here: Channels = selected antennas, Height = window_size, Width = features.
        # Permute from (window_size, features, channels) to (channels, window_size, features).
        window_data = window_data.permute(2, 0, 1)

        label = self.all_labels_for_windows[idx]
        return window_data, torch.tensor(label, dtype=torch.long)

<h3> Random Seed </h3>

In [4]:
def set_seed(seed: int):
    """Seed Python, NumPy and PyTorch RNGs and force deterministic cuDNN.

    CUDA generators are seeded only when a GPU is available. cuDNN is switched
    to deterministic kernels with autotuning (benchmark mode) disabled.
    """
    for seed_fn in (random.seed, np.random.seed, torch.manual_seed):
        seed_fn(seed)

    if torch.cuda.is_available():
        for cuda_seed_fn in (torch.cuda.manual_seed, torch.cuda.manual_seed_all):
            cuda_seed_fn(seed)

    cudnn_backend = torch.backends.cudnn
    cudnn_backend.deterministic = True
    cudnn_backend.benchmark = False