In [11]:
import os
import math
import string
import pickle
import numpy as np
import scipy.io as sio
import torch
import torch.nn as nn
import torch.optim as optim
from torch.utils.data import Dataset, DataLoader
import wget
import zipfile
import json5  
from safetensors.torch import load_file, save_file 
from typing import Tuple, List, Dict, Any, Union
import pandas as pd

**Configurazione**

In [12]:
class TrainingConfig:
    """Training configuration loaded from a JSON5 file.

    Every top-level key of the file becomes an instance attribute. Derived
    fields (channel count, run name, output directories and the final decoder
    layer) are then computed by ``__post_init__``.
    """

    def __init__(self, config_path: str = 'config.json5'):
        """Load ``config_path`` and populate the instance.

        Args:
            config_path: Path of the JSON5 configuration file.
        """
        print(f"Caricamento configurazione da: {config_path}")
        with open(config_path, 'r') as f:
            config_data = json5.load(f)

        for key, value in config_data.items():
            setattr(self, key, value)

        self.__post_init__()

    def __post_init__(self):
        """Compute derived fields and create the output directories.

        Reads ``antenna_indices`` (optional), ``checkpoint_dir_base``,
        ``log_dir_base``, ``base_data_dir``, ``dataset_name`` and
        ``decoder_conv_transpose_configs``; a 1x1 Sigmoid layer with one output
        channel per selected antenna is appended to the latter.
        """
        # A missing key is handled like an empty selection: the warning below
        # promises a fallback for the "not specified" case, so it must not
        # fail with AttributeError when the key is absent from the file.
        antenna_indices = getattr(self, 'antenna_indices', None)
        if isinstance(antenna_indices, int):
            antenna_indices = [antenna_indices]

        if not antenna_indices:
            antenna_indices = list(range(4))
            print("ATTENZIONE: 'antenna_indices' non specificato, si useranno tutte e 4 le antenne.")

        # Always store a list so that it selects along a single array axis
        # when used as an index.
        self.antenna_indices = list(antenna_indices)
        self.input_channels = len(self.antenna_indices)

        if self.input_channels == 1:
            run_name_prefix = f'single_antenna_{self.antenna_indices[0]}'
        elif self.input_channels == 4:
            run_name_prefix = 'all_antennas'
        else:
            antennas_str = '_'.join(map(str, sorted(self.antenna_indices)))
            run_name_prefix = f'custom_antennas_{antennas_str}'

        self.current_run_name = f'{run_name_prefix}'

        self.current_checkpoint_dir = os.path.join(self.checkpoint_dir_base, self.current_run_name)
        self.current_log_dir = os.path.join(self.log_dir_base, self.current_run_name)
        self.dataset_path = os.path.join(self.base_data_dir, self.dataset_name, "dataset")

        os.makedirs(self.current_checkpoint_dir, exist_ok=True)
        os.makedirs(self.current_log_dir, exist_ok=True)
        os.makedirs(self.dataset_path, exist_ok=True)

        # The last decoder layer maps back to one channel per selected antenna.
        final_decoder_out_channels = self.input_channels
        self.decoder_conv_transpose_configs.append(
            {"out_channels": final_decoder_out_channels, "kernel_size": (1, 1), "stride": (1, 1), "padding": 0, "activation": "Sigmoid"}
        )

    @staticmethod
    def get_activation(activation_name: str) -> nn.Module:
        """Instantiate the ``torch.nn`` module called ``activation_name``.

        Args:
            activation_name: Attribute name inside ``torch.nn`` (e.g. ``"ReLU"``),
                or ``None`` for no activation.

        Returns:
            A new module instance; ``nn.Identity()`` when the name is ``None``.

        Raises:
            ValueError: If ``torch.nn`` has no attribute with that name.
        """
        if activation_name is None:
            return nn.Identity()

        try:
            activation_class = getattr(nn, activation_name)
            return activation_class()
        except AttributeError:
            raise ValueError(f"Funzione di attivazione non trovata in torch.nn: '{activation_name}'")

    def __repr__(self):
        return f"TrainingConfig(run_name='{self.current_run_name}', device='{self.device}')"


**Download e Preparazione Dati**

In [13]:
def download_and_prepare_data(config: TrainingConfig):
    """Make sure the raw dataset is extracted under ``config.base_data_dir``.

    Nothing happens when the first expected MAT file is already present in
    ``config.dataset_path``. Otherwise the ZIP archive is downloaded from
    ``config.raw_data_zip_url`` (unless it is already on disk) and unpacked
    into ``<base_data_dir>/<dataset_name>``.
    """
    extract_dir = os.path.join(config.base_data_dir, config.dataset_name)
    archive_path = os.path.join(config.base_data_dir, f"{config.dataset_name}.zip")
    marker_file = os.path.join(config.dataset_path, f"{config.dataset_name}a_A.mat")

    # The first activity file acts as the "already extracted" marker.
    if os.path.exists(marker_file):
        print(f"Dati già trovati in {config.dataset_path}")
        return

    if os.path.exists(archive_path):
        print(f"File ZIP trovato: {archive_path}")
    else:
        print(f"Download dati da {config.raw_data_zip_url}...")
        wget.download(config.raw_data_zip_url, archive_path)
        print("Download completato.")

    print(f"Estrazione dati in {extract_dir}...")
    with zipfile.ZipFile(archive_path, 'r') as archive:
        archive.extractall(extract_dir)
    print("Estrazione completata.")


**CLASSE DATASET**

In [14]:
class CsiPyTorchDataset(Dataset):
    """Sliding-window dataset over CSI amplitudes stored in MAT files.

    Each file in ``file_list`` is one activity; its position in the list is the
    class label. For every file the ``'csi'`` array is read, truncated to
    ``config.samples_per_file`` rows, restricted to ``config.antenna_indices``
    along the third axis, converted to rounded magnitudes and appended to one
    concatenated tensor. Windows of ``config.window_size`` consecutive rows
    never cross a file boundary. The concatenated tensor is scaled by its
    global maximum.
    """

    def __init__(self, config: TrainingConfig, file_list: List[str]):
        """Load all files and index every valid window.

        Args:
            config: Object providing ``window_size``, ``samples_per_file``,
                ``antenna_indices`` and ``input_channels``.
            file_list: MAT file paths, one per activity.

        Raises:
            ValueError: If a loaded ``'csi'`` array has fewer than 3 dimensions
                or too few antennas for the requested indices.
            RuntimeError: If no file could be loaded at all.
        """
        self.config = config
        self.window_size = config.window_size
        self.samples_per_file = config.samples_per_file
        self.antenna_indices = config.antenna_indices
        self.input_channels = config.input_channels

        self.all_csi_segments = []
        self.all_labels_for_windows = []
        self.all_start_indices_in_concatenated_csi = []

        # Loop-invariant: highest antenna index every file must provide.
        max_antenna_idx = max(self.antenna_indices)

        current_concat_offset = 0
        print("Caricamento file MAT...")
        for activity_idx, file_path in enumerate(file_list):
            try:
                mat = sio.loadmat(file_path)
                data = np.array(mat['csi'])  # indexed below as (samples, features, antennas)
            except Exception as e:
                # Best effort: an unreadable file is reported and skipped.
                print(f"Errore nel caricare {file_path}: {e}. Saltato.")
                continue

            if data.ndim < 3 or data.shape[2] <= max_antenna_idx:
                raise ValueError(f"Dati in {file_path} non compatibili. Richiesto indice antenna {max_antenna_idx}, "
                                 f"ma la shape dei dati è {data.shape}.")

            samples_to_take = min(self.samples_per_file, data.shape[0])
            selected_data = data[:samples_to_take, :, self.antenna_indices]

            # Defensive: keep an explicit antenna axis even if it got dropped.
            if selected_data.ndim == 2:
                selected_data = np.expand_dims(selected_data, axis=2)

            data = np.round(np.abs(selected_data)).astype(np.float32)
            self.all_csi_segments.append(torch.from_numpy(data))

            # The segment is part of the concatenated tensor from here on, so
            # the running offset must advance even when the file turns out to
            # be too short for a single window. Otherwise all later window
            # starts would point into the wrong recording.
            segment_start = current_concat_offset
            current_concat_offset += data.shape[0]

            num_possible_windows_this_file = data.shape[0] - self.window_size + 1
            if num_possible_windows_this_file <= 0:
                print(f"Attenzione: samples_per_file ({data.shape[0]}) in {file_path} è minore di window_size ({self.window_size}).")
                continue

            self.all_start_indices_in_concatenated_csi.extend(
                range(segment_start, segment_start + num_possible_windows_this_file)
            )
            self.all_labels_for_windows.extend([activity_idx] * num_possible_windows_this_file)

            print(f"Processato {file_path}, {num_possible_windows_this_file} finestre aggiunte.")

        if not self.all_csi_segments:
            raise RuntimeError("Nessun dato CSI caricato. Controlla i percorsi e i file MAT.")

        self.csi_data_concatenated = torch.cat(self.all_csi_segments, dim=0)

        # Scale by the global maximum (values are magnitudes, hence >= 0).
        if self.csi_data_concatenated.numel() > 0:
            max_val = torch.max(self.csi_data_concatenated)
            if max_val > 0:
                self.csi_data_concatenated /= max_val

        print(f"Dataset inizializzato. CSI shape: {self.csi_data_concatenated.shape}")
        print(f"Numero totale di finestre: {len(self)}")

    def __len__(self):
        """Return the total number of windows over all files."""
        return len(self.all_start_indices_in_concatenated_csi)

    def __getitem__(self, idx: int) -> Tuple[torch.Tensor, torch.Tensor]:
        """Return window ``idx`` as ``(channels, window_size, features)`` and its label."""
        actual_start_idx = self.all_start_indices_in_concatenated_csi[idx]
        window_data = self.csi_data_concatenated[actual_start_idx : actual_start_idx + self.window_size, ...]

        # Conv2d expects (batch, channels, height, width): move the antenna
        # axis to the front so that antennas act as channels.
        window_data = window_data.permute(2, 0, 1)

        label = self.all_labels_for_windows[idx]
        return window_data, torch.tensor(label, dtype=torch.long)


**ENCODER**

In [15]:
class Encoder(nn.Module):
    """Convolutional VAE encoder built from the layer lists in the config.

    The network is ``encoder_conv_configs`` (``nn.Conv2d`` layers, each
    optionally followed by an activation), a flatten, ``encoder_fc_configs``
    (``nn.Linear`` layers, each optionally followed by an activation) and two
    linear heads producing the mean and log-variance of the latent Gaussian.

    Attributes:
        conv_to_flatten_shape: ``(channels, height, width)`` of the feature map
            right before the flatten.
    """

    @staticmethod
    def _as_pair(value) -> Tuple[int, int]:
        """Expand a scalar Conv2d hyper-parameter to ``(height, width)``."""
        if isinstance(value, int):
            return value, value
        first, second = value
        return first, second

    @staticmethod
    def _conv_out_size(size: int, kernel: int, stride: int, padding: int, dilation: int) -> int:
        """Output length of one spatial axis (formula from the Conv2d docs)."""
        return (size + 2 * padding - dilation * (kernel - 1) - 1) // stride + 1

    def __init__(self, config: TrainingConfig):
        """Build the layers.

        Args:
            config: Object providing ``input_channels``, ``window_size``,
                ``feature_dim``, ``encoder_conv_configs``,
                ``encoder_fc_configs`` and ``latent_dim``. Every key of a conv
                entry except ``'activation'`` is forwarded to ``nn.Conv2d``.
                ``kernel_size``, ``stride``, ``padding`` and ``dilation`` may
                be ints or pairs; string padding modes are not supported by
                the shape bookkeeping.
        """
        super().__init__()
        self.config = config
        layers = []
        current_channels = config.input_channels
        # Spatial size of the feature map, tracked analytically layer by layer.
        height, width = config.window_size, config.feature_dim

        for layer_cfg in config.encoder_conv_configs:
            conv_params = {k: v for k, v in layer_cfg.items() if k != 'activation'}
            layers.append(nn.Conv2d(current_channels, **conv_params))
            current_channels = layer_cfg["out_channels"]
            if "activation" in layer_cfg:
                layers.append(TrainingConfig.get_activation(layer_cfg["activation"]))

            # Use the same defaults as nn.Conv2d for keys missing in the config.
            kh, kw = self._as_pair(layer_cfg["kernel_size"])
            sh, sw = self._as_pair(layer_cfg.get("stride", 1))
            ph, pw = self._as_pair(layer_cfg.get("padding", 0))
            dh, dw = self._as_pair(layer_cfg.get("dilation", 1))
            height = self._conv_out_size(height, kh, sh, ph, dh)
            width = self._conv_out_size(width, kw, sw, pw, dw)

        layers.append(nn.Flatten())
        self.conv_to_flatten_shape = (current_channels, height, width)
        flattened_size = current_channels * height * width

        current_features = flattened_size
        for layer_cfg in config.encoder_fc_configs:
            layers.append(nn.Linear(current_features, layer_cfg["out_features"]))
            current_features = layer_cfg["out_features"]
            if "activation" in layer_cfg:
                layers.append(TrainingConfig.get_activation(layer_cfg["activation"]))

        self.model = nn.Sequential(*layers)
        self.fc_z_mean = nn.Linear(current_features, config.latent_dim)
        self.fc_z_log_var = nn.Linear(current_features, config.latent_dim)

    def forward(self, x: torch.Tensor) -> Tuple[torch.Tensor, torch.Tensor, torch.Tensor]:
        """Encode ``x`` and draw one latent sample per input.

        Returns:
            ``(z_mean, z_log_var, z)`` where
            ``z = z_mean + eps * exp(0.5 * z_log_var)`` with ``eps ~ N(0, I)``
            (reparameterization trick).
        """
        x_encoded = self.model(x)
        z_mean = self.fc_z_mean(x_encoded)
        z_log_var = self.fc_z_log_var(x_encoded)
        std = torch.exp(0.5 * z_log_var)
        eps = torch.randn_like(std)
        z = z_mean + eps * std
        return z_mean, z_log_var, z

**DECODER**

In [16]:
class Decoder(nn.Module):
    """VAE decoder: fully connected layers followed by transposed convolutions.

    The fully connected part maps a latent vector to
    ``prod(encoder_conv_output_shape)`` features, which are reshaped to that
    shape and passed through the ``nn.ConvTranspose2d`` layers listed in
    ``config.decoder_conv_transpose_configs``.
    """

    def __init__(self, config: TrainingConfig, encoder_conv_output_shape: Tuple[int, int, int]):
        """Build the layers.

        Args:
            config: Object providing ``latent_dim``, ``decoder_fc_configs``,
                ``encoder_fc_configs`` and ``decoder_conv_transpose_configs``.
            encoder_conv_output_shape: ``(channels, height, width)`` of the
                encoder feature map before flattening.
        """
        super().__init__()
        self.config = config
        self.encoder_conv_output_shape = encoder_conv_output_shape
        layers = []
        # Plain Python int: np.prod returns a NumPy scalar.
        decoder_start_features = int(np.prod(encoder_conv_output_shape))

        current_features = config.latent_dim
        for layer_cfg in (config.decoder_fc_configs or []):
            layers.append(nn.Linear(current_features, layer_cfg["out_features"]))
            current_features = layer_cfg["out_features"]
            if "activation" in layer_cfg:
                layers.append(TrainingConfig.get_activation(layer_cfg["activation"]))
        # Projection to the size of the encoder feature map (directly from the
        # latent vector when no hidden FC layers are configured).
        layers.append(nn.Linear(current_features, decoder_start_features))

        # The projection reuses the activation of the last encoder FC layer.
        # 'activation' is optional there (the encoder checks for the key too),
        # so it must not be required here.
        if config.encoder_fc_configs:
            last_encoder_fc = config.encoder_fc_configs[-1]
            if "activation" in last_encoder_fc:
                layers.append(TrainingConfig.get_activation(last_encoder_fc["activation"]))

        self.fc_part = nn.Sequential(*layers)

        conv_transpose_layers = []
        current_channels = self.encoder_conv_output_shape[0]
        for layer_cfg in config.decoder_conv_transpose_configs:
            # Everything except 'activation' is forwarded to ConvTranspose2d.
            conv_params = {k: v for k, v in layer_cfg.items() if k != 'activation'}
            conv_transpose_layers.append(nn.ConvTranspose2d(current_channels, **conv_params))
            current_channels = layer_cfg["out_channels"]
            if "activation" in layer_cfg:
                conv_transpose_layers.append(TrainingConfig.get_activation(layer_cfg["activation"]))

        self.conv_transpose_part = nn.Sequential(*conv_transpose_layers)

    def forward(self, z: torch.Tensor) -> torch.Tensor:
        """Decode latent vectors ``z`` of shape ``(batch, latent_dim)``."""
        x = self.fc_part(z)
        x = x.view(-1, *self.encoder_conv_output_shape)
        x_reconstructed = self.conv_transpose_part(x)
        return x_reconstructed

**VAE**

In [17]:
class VAE(nn.Module):
    """Variational autoencoder pairing ``Encoder`` with a matching ``Decoder``."""

    def __init__(self, config: TrainingConfig):
        super().__init__()
        self.encoder = Encoder(config)
        # The decoder starts from the feature-map shape the encoder flattens.
        self.decoder = Decoder(config, self.encoder.conv_to_flatten_shape)

    def forward(self, x: torch.Tensor) -> Tuple[torch.Tensor, torch.Tensor, torch.Tensor, torch.Tensor]:
        """Return ``(reconstruction, z_mean, z_log_var, z)`` for the batch ``x``."""
        mu, log_var, latent_sample = self.encoder(x)
        return self.decoder(latent_sample), mu, log_var, latent_sample

    def loss_function(self, x_original: torch.Tensor, x_reconstructed: torch.Tensor, 
                      z_mean: torch.Tensor, z_log_var: torch.Tensor) -> Dict[str, torch.Tensor]:
        """Compute the VAE loss terms.

        The reconstruction term is the binary cross-entropy summed over
        dimensions 1-3 of each sample; the KL term is the closed-form
        divergence from a standard normal, summed over the latent dimension.
        Both are averaged over the batch.

        Returns:
            Dict with keys ``total_loss``, ``reconstruction_loss`` and
            ``kl_loss``.
        """
        elementwise_bce = nn.functional.binary_cross_entropy(
            x_reconstructed, x_original, reduction='none'
        )
        per_sample_reconstruction = torch.sum(elementwise_bce, dim=(1, 2, 3))
        reconstruction_loss = torch.mean(per_sample_reconstruction)

        per_sample_kl = -0.5 * torch.sum(1 + z_log_var - z_mean.pow(2) - z_log_var.exp(), dim=1)
        kl_loss = torch.mean(per_sample_kl)

        return {
            "total_loss": reconstruction_loss + kl_loss,
            "reconstruction_loss": reconstruction_loss,
            "kl_loss": kl_loss,
        }


**TRAINING LOOP**

In [18]:
def train_vae(config: TrainingConfig, vae: VAE, train_loader: DataLoader):
    """Train ``vae`` with CSV logging, checkpoints and early stopping.

    Args:
        config: Provides ``device``, ``learning_rate``, ``epochs``,
            ``current_log_dir``, ``current_checkpoint_dir``,
            ``load_pretrained_if_exists``, ``early_stopping_monitor``,
            ``early_stopping_min_delta``, ``early_stopping_patience`` and
            ``save_every_n_epochs``.
        vae: Model to train; moved to ``config.device`` and updated in place.
        train_loader: Yields ``(data, label)`` batches; labels are ignored.

    Side effects:
        Rewrites ``training_log.csv`` after every epoch and writes
        ``vae_best.safetensors``, periodic ``vae_epoch_<n>.safetensors`` and
        ``vae_final.safetensors`` into the checkpoint directory.
    """
    device = torch.device(config.device)
    vae.to(device)
    optimizer = optim.Adam(vae.parameters(), lr=config.learning_rate)

    start_epoch = 0
    best_loss = float('inf')
    epochs_no_improve = 0

    # The log is kept as a list of row dicts and written out once per epoch;
    # this avoids growing a DataFrame with pd.concat inside the loop.
    log_path = os.path.join(config.current_log_dir, "training_log.csv")
    log_records = []
    if os.path.exists(log_path) and config.load_pretrained_if_exists:
        previous_log = pd.read_csv(log_path)
        if not previous_log.empty:
            start_epoch = int(previous_log["epoch"].iloc[-1]) + 1
            best_loss = float(previous_log[config.early_stopping_monitor].min())
            log_records = previous_log.to_dict("records")

    best_checkpoint_path = os.path.join(config.current_checkpoint_dir, "vae_best.safetensors")
    if config.load_pretrained_if_exists:
        if os.path.exists(best_checkpoint_path):
            print(f"Caricamento modello da {best_checkpoint_path}")
            # safetensors' load_file(filename, device) returns a state dict;
            # it does not take the model as an argument.
            vae.load_state_dict(load_file(best_checkpoint_path, device=str(device)))
            print(f"Modello caricato. Riprendendo da epoch {start_epoch}.")
        else:
            print("Nessun checkpoint .safetensors trovato. Addestramento da zero.")
    # NOTE(review): only the model weights are restored; the Adam state starts
    # fresh on resume because it is never saved.

    num_batches = len(train_loader)
    for epoch in range(start_epoch, config.epochs):
        vae.train()
        epoch_losses = {"total_loss": 0, "reconstruction_loss": 0, "kl_loss": 0}

        for data, _ in train_loader:
            data = data.to(device)
            optimizer.zero_grad()
            x_reconstructed, z_mean, z_log_var, _ = vae(data)
            losses = vae.loss_function(data, x_reconstructed, z_mean, z_log_var)
            losses["total_loss"].backward()
            optimizer.step()
            for k, v in losses.items():
                epoch_losses[k] += v.item()

        avg_losses = {k: v / num_batches for k, v in epoch_losses.items()}
        print(f"--- Epoch {epoch+1} | Avg Loss: {avg_losses['total_loss']:.4f} ---")

        log_records.append({"epoch": epoch, **avg_losses})
        pd.DataFrame(log_records).to_csv(log_path, index=False)

        # Early stopping: an epoch counts as an improvement only if it beats
        # the best value by more than min_delta.
        current_loss = avg_losses[config.early_stopping_monitor]
        if current_loss < best_loss - config.early_stopping_min_delta:
            best_loss = current_loss
            epochs_no_improve = 0
            save_file(vae.state_dict(), best_checkpoint_path)
            print(f"Nuovo best loss: {best_loss:.4f}. Checkpoint salvato come vae_best.safetensors")
        else:
            epochs_no_improve += 1

        if epoch > 0 and epoch % config.save_every_n_epochs == 0:
            save_file(vae.state_dict(), os.path.join(config.current_checkpoint_dir, f"vae_epoch_{epoch}.safetensors"))
            print(f"Checkpoint salvato per l'epoca {epoch} come .safetensors")

        if epochs_no_improve >= config.early_stopping_patience:
            print(f"Early stopping attivato dopo {epoch+1} epoche.")
            break

    save_file(vae.state_dict(), os.path.join(config.current_checkpoint_dir, "vae_final.safetensors"))
    print("Addestramento completato. Modello finale salvato come vae_final.safetensors")



**LATENT SPACE**

In [19]:
def generate_and_save_latent_space(config: TrainingConfig, vae: VAE, data_loader: DataLoader):
    """Encode every batch of ``data_loader`` and pickle the latent statistics.

    The weights are loaded from ``vae_best.safetensors`` or, if that file is
    missing, from ``vae_final.safetensors``; if neither exists a warning is
    printed and the model is used as passed in.

    The output file ``<base_data_dir>/latent_space_dataset_pytorch/<run>.pkl``
    holds the list ``[z, labels]`` where ``z`` concatenates ``z_mean`` and
    ``z_log_var`` along axis 1. Rows follow the iteration order of
    ``data_loader``.
    """
    device = torch.device(config.device)
    vae.to(device)
    vae.eval()

    checkpoint_path = os.path.join(config.current_checkpoint_dir, "vae_best.safetensors")
    if not os.path.exists(checkpoint_path):
        checkpoint_path = os.path.join(config.current_checkpoint_dir, "vae_final.safetensors")

    if os.path.exists(checkpoint_path):
        print(f"Caricamento modello per inferenza da {checkpoint_path}")
        # safetensors' load_file(filename, device) returns a state dict; it
        # does not take the model as an argument.
        vae.load_state_dict(load_file(checkpoint_path, device=str(device)))
    else:
        print("ATTENZIONE: Nessun modello .safetensors trovato per l'inferenza.")

    all_z_mean, all_z_log_var, all_labels = [], [], []
    with torch.no_grad():
        for data, labels in data_loader:
            data = data.to(device)
            # The sampled z is discarded: only the distribution parameters are stored.
            z_mean, z_log_var, _ = vae.encoder(data)
            all_z_mean.append(z_mean.cpu().numpy())
            all_z_log_var.append(z_log_var.cpu().numpy())
            all_labels.append(labels.cpu().numpy())

    z_data_combined = np.concatenate([np.concatenate(all_z_mean), np.concatenate(all_z_log_var)], axis=1)
    labels_np = np.concatenate(all_labels)

    latent_space_dir = os.path.join(config.base_data_dir, "latent_space_dataset_pytorch")
    os.makedirs(latent_space_dir, exist_ok=True)

    output_filename = os.path.join(latent_space_dir, f"{config.current_run_name}.pkl")
    with open(output_filename, 'wb') as f:
        pickle.dump([z_data_combined, labels_np], f)
    print(f"Dati dello spazio latente salvati in: {output_filename}")



**MAIN**

In [None]:
if __name__ == "__main__":
    # End-to-end run: load config, fetch data, build dataset, train, export latents.
    config = TrainingConfig(config_path='config.json5')
    print(f"Configurazione caricata: {config}")

    download_and_prepare_data(config)

    # One MAT file per activity, named <dataset>a_A.mat, <dataset>a_B.mat, ...
    file_list = [os.path.join(config.dataset_path, f"{config.dataset_name}a_{x}.mat")
                 for x in string.ascii_uppercase[:config.num_activities]]
    print(f"Lista file per il dataset: {file_list}")

    # Worker processes must import the dataset class; a class defined in a
    # notebook cell is not importable from workers on platforms that spawn
    # processes (e.g. Windows), so loading happens in the main process unless
    # the config explicitly sets 'num_workers'. Pinned memory only helps
    # host-to-GPU copies.
    loader_kwargs = dict(
        batch_size=config.batch_size,
        num_workers=getattr(config, "num_workers", 0),
        pin_memory=torch.device(config.device).type == "cuda",
    )

    try:
        csi_dataset = CsiPyTorchDataset(config, file_list)
        if len(csi_dataset) == 0:
            print("ERRORE: Dataset vuoto. Terminazione.")
            # SystemExit stops the cell; exit() would ask the kernel to shut down.
            raise SystemExit(1)
        train_loader = DataLoader(csi_dataset, shuffle=True, **loader_kwargs)
        # Unshuffled loader so the exported latent rows follow dataset order.
        latent_loader = DataLoader(csi_dataset, shuffle=False, **loader_kwargs)
    except Exception as e:
        import traceback
        print(f"Errore fatale durante la creazione del dataset/dataloader: {e}")
        traceback.print_exc()
        raise SystemExit(1)

    vae_model = VAE(config)
    print(f"Parametri totali del modello: {sum(p.numel() for p in vae_model.parameters())}")

    print("\n--- Inizio Addestramento ---")
    train_vae(config, vae_model, train_loader)

    print("\n--- Inizio Generazione Spazio Latente ---")
    generate_and_save_latent_space(config, vae_model, latent_loader)

    print("\nProcesso completato.")

Caricamento configurazione da: config.json5
Configurazione caricata: TrainingConfig(run_name='custom_antennas_1_2', device='cpu')
Dati già trovati in ./data\S1\dataset
Lista file per il dataset: ['./data\\S1\\dataset\\S1a_A.mat', './data\\S1\\dataset\\S1a_B.mat', './data\\S1\\dataset\\S1a_C.mat', './data\\S1\\dataset\\S1a_D.mat', './data\\S1\\dataset\\S1a_E.mat']
Caricamento file MAT...
Processato ./data\S1\dataset\S1a_A.mat, 11551 finestre aggiunte.
Processato ./data\S1\dataset\S1a_B.mat, 11551 finestre aggiunte.
Processato ./data\S1\dataset\S1a_C.mat, 11551 finestre aggiunte.
Processato ./data\S1\dataset\S1a_D.mat, 11551 finestre aggiunte.
Processato ./data\S1\dataset\S1a_E.mat, 11551 finestre aggiunte.
Dataset inizializzato. CSI shape: torch.Size([60000, 2048, 2])
Numero totale di finestre: 57755
Parametri totali del modello: 185942

--- Inizio Addestramento ---
