<h1> VAE training and processing</h1>

<h3>Sample code to train a new VAE and run the CSI processing</h3>


In [None]:
import os
import math
import string
import pickle
import numpy as np
import scipy.io as sio
import torch
import torch.nn as nn
import torch.optim as optim
from torch.utils.data import Dataset, DataLoader
import wget
import zipfile
import json5  
from safetensors.torch import load_file, save_file 
from typing import Tuple, List, Dict
import pandas as pd

import import_ipynb

from config_setup import TrainingConfig
from data_utils import download_and_prepare_data, CsiPyTorchDataset,set_seed
from model import VAE
from training import train_vae, generate_and_save_latent_space

In [None]:
if __name__ == "__main__":
    # Experiment sweep: train one VAE for every combination of window size,
    # antenna subset and random seed listed in the `experiment_parameters`
    # section of config.json5, then generate its latent space.

    config = TrainingConfig(config_path='config.json5')

    params = config.experiment_parameters
    window_sizes_to_test = params['window_sizes']
    antennas_to_test = params['antennas_to_test']
    random_seeds = params['random_seeds']

    for window_size in window_sizes_to_test:
        for antenna_idx_list in antennas_to_test:
            for seed in random_seeds:

                # The same config object is reused across runs; point it at
                # this combination's parameters before anything reads it.
                config.update_run_config(window_size=window_size, antenna_indices=antenna_idx_list, seed=seed)

                print(f"--- START EXPERIMENT: {config.current_run_name} ---")

                # Seed before data loading and model construction so that
                # shuffling and weight init depend only on `seed`
                # (assumes set_seed covers torch/numpy RNGs -- confirm in data_utils).
                set_seed(seed)

                # Missing input files only skip this combination; the rest of
                # the sweep still runs.
                try:
                    file_list = download_and_prepare_data(config)
                except FileNotFoundError as e:
                    print(e)
                    continue

                # Dataset problems are treated as fatal for the whole sweep.
                # `raise SystemExit(1)` is used instead of the site-module
                # `exit()`: that helper is for the interactive shell only, exits
                # with status 0 (masking the failure) in a script, and shuts
                # the kernel down under Jupyter. SystemExit is not an Exception
                # subclass, so the `except Exception` below does not catch it.
                try:
                    csi_dataset = CsiPyTorchDataset(config, file_list)
                    # Checked before building the DataLoader: a shuffling
                    # sampler rejects an empty dataset with its own error.
                    if len(csi_dataset) == 0:
                        print("ERROR: Dataset is empty. Terminating.")
                        raise SystemExit(1)

                    train_loader = DataLoader(
                        csi_dataset,
                        batch_size=config.batch_size,
                        shuffle=True,
                        num_workers=0,
                    )
                except Exception as e:
                    import traceback
                    print(f"Error during dataset/dataloader creation: {e}")
                    traceback.print_exc()
                    raise SystemExit(1) from e

                vae_model = VAE(config)
                print(f"Model created. Total parameter: {sum(p.numel() for p in vae_model.parameters())}")

                print("\n--- Starting Training ---")
                train_vae(config, vae_model, train_loader)

                # NOTE(review): the shuffled training loader is reused here. If
                # generate_and_save_latent_space stores latents per sample, their
                # order will not match dataset order -- confirm whether a
                # shuffle=False loader is needed.
                print("\n--- Starting Latent Space Generation ---")
                generate_and_save_latent_space(config, vae_model, train_loader)

                print(f"--- END OF EXPERIMENT: {config.current_run_name} ---")