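# Training

Trains the VAE defined in `model` (optionally resuming from saved safetensors checkpoints) and then exports the latent representation of the training dataset.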

In [2]:
import os
import math
import string
import pickle
import numpy as np
import scipy.io as sio
import torch
import torch.nn as nn
import torch.optim as optim
from torch.utils.data import Dataset, DataLoader
import wget
import zipfile
import json5  
from safetensors.torch import load_file, save_file 
from typing import Tuple, List, Dict
import pandas as pd
import import_ipynb



from config_setup import TrainingConfig
from model import VAE

In [3]:
def _restore_training_state(config, vae, device, log_path, checkpoint_path):
    """Decide where training starts; load the best checkpoint into ``vae`` when resuming.

    Returns ``(start_epoch, best_loss, log_records)``. Previous progress from the CSV log
    is only reused when the checkpoint it belongs to is actually loaded; otherwise the
    run starts from epoch 0 with an empty log, so a from-scratch model is never compared
    against (or numbered after) a previous run's losses.
    """
    start_epoch, best_loss, log_records = 0, float('inf'), []
    if not config.load_pretrained_if_exists:
        return start_epoch, best_loss, log_records
    if not os.path.exists(checkpoint_path):
        print("No safetensors checkpoint found. Training from scratch.")
        return start_epoch, best_loss, log_records

    print(f"Loading model from {checkpoint_path}")
    # load_state_dict copies into the existing parameters, so the optimizer that was
    # built on vae.parameters() stays valid.
    vae.load_state_dict(load_file(checkpoint_path, device=str(device)))

    if os.path.exists(log_path):
        log_df = pd.read_csv(log_path)
        if not log_df.empty:
            start_epoch = int(log_df["epoch"].iloc[-1]) + 1
            best_loss = float(log_df[config.early_stopping_monitor].min())
        log_records = log_df.to_dict("records")
    print(f"Model loaded. Resuming from epoch {start_epoch}.")
    return start_epoch, best_loss, log_records


def _gumbel_temperature(config, epoch):
    """Gumbel-Softmax temperature for a 0-based ``epoch``; 1.0 when the latent is not Gumbel.

    Exponential decay ``initial_temperature * exp(-annealing_rate * epoch)``, floored at
    ``min_temperature`` (all read from ``config.gumbel_annealing``).
    """
    if config.latent_distribution != 'gumbel':
        return 1.0
    anneal = config.gumbel_annealing
    decayed = anneal['initial_temperature'] * math.exp(-anneal['annealing_rate'] * epoch)
    return max(decayed, anneal['min_temperature'])


def _run_train_epoch(vae, train_loader, optimizer, device, temperature):
    """Run one optimisation pass over ``train_loader``; return each loss averaged over batches.

    Raises:
        ValueError: if ``train_loader`` yields no batches.
    """
    vae.train()
    totals = {"total_loss": 0.0, "reconstruction_loss": 0.0, "kl_loss": 0.0}
    n_batches = 0
    for data, _ in train_loader:
        data = data.to(device)
        optimizer.zero_grad()
        model_output = vae(data, temperature=temperature)
        losses = vae.loss_function(data, model_output)
        losses["total_loss"].backward()
        optimizer.step()
        # .get() so additional loss terms returned by loss_function are tracked too.
        for name, value in losses.items():
            totals[name] = totals.get(name, 0.0) + value.item()
        # Count batches directly: len() is unavailable for iterable-style datasets.
        n_batches += 1
    if n_batches == 0:
        raise ValueError("train_loader yielded no batches; cannot compute epoch losses.")
    return {name: total / n_batches for name, total in totals.items()}


def train_vae(config: TrainingConfig, vae: VAE, train_loader: DataLoader):
    """Train ``vae`` on ``train_loader`` with early stopping and safetensors checkpoints.

    Per-epoch average losses are appended to ``<config.current_log_dir>/training_log.csv``
    (the ``epoch`` column is 0-based). When ``config.load_pretrained_if_exists`` is true
    and ``vae_best.safetensors`` exists, its weights are loaded and training resumes after
    the last logged epoch. The optimizer and the early-stopping counter are not restored,
    so both restart on resume.

    Checkpoints written to ``config.current_checkpoint_dir``:
      * ``vae_best.safetensors``: whenever the monitored loss improves by more than
        ``config.early_stopping_min_delta``;
      * ``vae_epoch_{N}.safetensors``: every ``config.save_every_n_epochs`` completed
        epochs, N being the 1-based number of completed epochs (matching the
        "Epoch N" progress messages); disabled when the setting is 0/None;
      * ``vae_final.safetensors``: the weights when training ends.

    Args:
        config: Training configuration (device, learning rate, epoch budget,
            early-stopping settings, log/checkpoint directories, latent distribution and
            Gumbel annealing parameters).
        vae: Model called as ``vae(data, temperature=...)`` whose
            ``loss_function(data, model_output)`` returns a dict of scalar tensors
            containing ``"total_loss"``.
        train_loader: Loader yielding ``(data, label)`` batches; labels are ignored.

    Raises:
        ValueError: if ``train_loader`` yields no batches.
    """
    device = torch.device(config.device)
    vae.to(device)
    optimizer = optim.Adam(vae.parameters(), lr=config.learning_rate)

    os.makedirs(config.current_log_dir, exist_ok=True)
    os.makedirs(config.current_checkpoint_dir, exist_ok=True)
    log_path = os.path.join(config.current_log_dir, "training_log.csv")
    best_checkpoint_path = os.path.join(config.current_checkpoint_dir, "vae_best.safetensors")

    start_epoch, best_loss, log_records = _restore_training_state(
        config, vae, device, log_path, best_checkpoint_path
    )
    epochs_no_improve = 0

    for epoch in range(start_epoch, config.epochs):
        current_tau = _gumbel_temperature(config, epoch)
        if config.latent_distribution == 'gumbel':
            print(f"Epoch {epoch+1} | Gumbel Temperature (tau): {current_tau:.4f}")

        avg_losses = _run_train_epoch(vae, train_loader, optimizer, device, current_tau)
        print(f"--- Epoch {epoch+1} | Avg Loss: {avg_losses['total_loss']:.4f} ---")

        # Rewrite the whole log each epoch so an interrupted run keeps its history.
        log_records.append({"epoch": epoch, **avg_losses})
        pd.DataFrame(log_records).to_csv(log_path, index=False)

        current_loss = avg_losses[config.early_stopping_monitor]
        if current_loss < best_loss - config.early_stopping_min_delta:
            best_loss = current_loss
            epochs_no_improve = 0
            save_file(vae.state_dict(), best_checkpoint_path)
            print(f"New best loss: {best_loss:.4f}. Checkpoint saved as vae_best.safetensors")
        else:
            epochs_no_improve += 1

        completed_epochs = epoch + 1
        save_every = config.save_every_n_epochs
        if save_every and completed_epochs % save_every == 0:
            checkpoint_name = f"vae_epoch_{completed_epochs}.safetensors"
            save_file(vae.state_dict(), os.path.join(config.current_checkpoint_dir, checkpoint_name))
            print(f"Checkpoint saved for epoch {completed_epochs} as {checkpoint_name}")

        if epochs_no_improve >= config.early_stopping_patience:
            print(f"Early stopping triggered after {epoch+1} epochs.")
            break

    save_file(vae.state_dict(), os.path.join(config.current_checkpoint_dir, "vae_final.safetensors"))
    print("Training complete. Final model saved as vae_final.safetensors")


In [None]:
def generate_and_save_latent_space(config: TrainingConfig, vae: VAE, train_loader: DataLoader):
    """Encode every sample of ``train_loader.dataset`` and save latents + labels as safetensors.

    Weights are loaded from ``vae_best.safetensors`` in ``config.current_checkpoint_dir``,
    falling back to ``vae_final.safetensors``. If neither exists, a warning is printed
    and the model's current weights are used.

    The output file is
    ``<config.base_data_dir>/latent_space_dataset_<latent_distribution>/<config.current_run_name>.safetensors``.
    It contains two tensors:
      * ``"latent_space"``: the model output ``"z_mean"`` for Gaussian latents, or ``"z"``
        for Gumbel latents (sampled at ``gumbel_annealing['min_temperature']``);
      * ``"labels"``: the labels yielded by the dataset.
    Row i of both tensors corresponds to dataset item i (the dataset is iterated unshuffled).

    Raises:
        ValueError: if ``config.latent_distribution`` is neither ``'gaussian'`` nor
            ``'gumbel'``, or if the dataset yields no batches.
    """
    dist_tag = config.latent_distribution
    # Which entry of the model output is exported as the latent representation.
    latent_keys = {"gaussian": "z_mean", "gumbel": "z"}
    if dist_tag not in latent_keys:
        # Fail before the (potentially long) inference pass rather than at concatenation.
        raise ValueError(
            f"Unsupported latent_distribution {dist_tag!r}; expected one of {sorted(latent_keys)}."
        )
    latent_key = latent_keys[dist_tag]

    device = torch.device(config.device)
    vae.to(device)
    vae.eval()

    # Unshuffled, so the saved rows line up with the dataset's item order.
    inference_loader = DataLoader(
        train_loader.dataset,
        batch_size=config.batch_size,
        shuffle=False,
        num_workers=0,
        # Pinned host memory only helps host->CUDA copies; don't request it otherwise.
        pin_memory=device.type == "cuda",
    )

    checkpoint_path = os.path.join(config.current_checkpoint_dir, "vae_best.safetensors")
    if not os.path.exists(checkpoint_path):
        checkpoint_path = os.path.join(config.current_checkpoint_dir, "vae_final.safetensors")

    if os.path.exists(checkpoint_path):
        print(f"Loading model for inference from {checkpoint_path}")
        vae.load_state_dict(load_file(checkpoint_path, device=str(device)))
    else:
        print("WARNING: No safetensors model found for inference.")

    # Gumbel latents are sampled at the final (minimum) annealing temperature.
    inference_tau = 1.0
    if dist_tag == 'gumbel':
        inference_tau = config.gumbel_annealing['min_temperature']

    latent_batches = []
    label_batches = []
    with torch.no_grad():
        for data, labels in inference_loader:
            model_output = vae(data.to(device), temperature=inference_tau)
            # Gaussian: the encoder mean. Gumbel: the sample "z" drawn at inference_tau,
            # presumably a relaxed (soft) one-hot vector; verify against model.VAE.
            latent_batches.append(model_output[latent_key].cpu().numpy())
            label_batches.append(labels.cpu().numpy())

    if not latent_batches:
        raise ValueError("The dataset yielded no batches; there is no latent space to save.")

    latent_space_dir = os.path.join(config.base_data_dir, f"latent_space_dataset_{dist_tag}")
    os.makedirs(latent_space_dir, exist_ok=True)

    tensors_to_save = {
        "latent_space": torch.from_numpy(np.concatenate(latent_batches, axis=0)),
        "labels": torch.from_numpy(np.concatenate(label_batches, axis=0)),
    }
    output_filename = os.path.join(latent_space_dir, f"{config.current_run_name}.safetensors")
    save_file(tensors_to_save, output_filename)

    print(f"Latent space data saved to: {output_filename}")