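# Training

Train the VAE with early stopping and checkpointing (`train_vae`), then encode the dataset and save the latent statistics with their labels (`generate_and_save_latent_space`).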

In [13]:
import os
import math
import string
import pickle
import numpy as np
import scipy.io as sio
import torch
import torch.nn as nn
import torch.optim as optim
from torch.utils.data import Dataset, DataLoader
import wget
import zipfile
import json5  
from safetensors.torch import load_file, save_file 
from typing import Tuple, List, Dict
import pandas as pd
import import_ipynb



from config_setup import TrainingConfig
from model import VAE

In [14]:
def train_vae(config: TrainingConfig, vae: VAE, train_loader: DataLoader):
    """Train ``vae`` with Adam, logging per-epoch losses and saving safetensors checkpoints.

    Resume behaviour: when ``config.load_pretrained_if_exists`` is true and
    ``vae_best.safetensors`` exists in ``config.current_checkpoint_dir``, those
    weights are loaded and the epoch counter / best loss are resumed from
    ``training_log.csv`` in ``config.current_log_dir``. If no checkpoint is found,
    training starts from epoch 0 with a fresh log.

    Each epoch:
      * averages every loss returned by ``vae.loss_function`` over the batches,
      * appends the averages (plus ``epoch``) to the CSV log,
      * saves ``vae_best.safetensors`` when ``config.early_stopping_monitor``
        improves by more than ``config.early_stopping_min_delta``,
      * saves ``vae_epoch_{epoch}.safetensors`` every ``config.save_every_n_epochs``
        epochs (skipped when that setting is 0/None),
      * stops early after ``config.early_stopping_patience`` epochs without improvement.

    ``vae_final.safetensors`` is always written when the loop ends.

    Args:
        config: run configuration (device, learning rate, epochs, directories,
            early-stopping and checkpoint settings).
        vae: model called as ``vae(data) -> (x_reconstructed, z_mean, z_log_var, _)``
            and exposing ``loss_function(data, x_reconstructed, z_mean, z_log_var)``
            that returns a dict of scalar tensors including ``"total_loss"``.
        train_loader: yields ``(data, label)`` batches; labels are ignored.

    Raises:
        ValueError: if ``train_loader`` yields no batches in an epoch.
    """
    device = torch.device(config.device)
    vae.to(device)
    optimizer = optim.Adam(vae.parameters(), lr=config.learning_rate)

    # Neither safetensors.save_file nor DataFrame.to_csv creates parent directories.
    os.makedirs(config.current_log_dir, exist_ok=True)
    os.makedirs(config.current_checkpoint_dir, exist_ok=True)

    log_path = os.path.join(config.current_log_dir, "training_log.csv")
    best_path = os.path.join(config.current_checkpoint_dir, "vae_best.safetensors")

    start_epoch = 0
    best_loss = float("inf")
    epochs_no_improve = 0
    log_records: List[Dict] = []

    if config.load_pretrained_if_exists:
        if os.path.exists(best_path):
            print(f"Loading model from {best_path}")
            # load_file returns a state dict; it does not populate a module by itself.
            vae.load_state_dict(load_file(best_path, device=str(device)))
            # Only resume the epoch counter / best loss when weights were actually
            # restored; otherwise a stale log would make a from-scratch run skip epochs.
            # NOTE(review): weights come from the *best* epoch while the counter resumes
            # after the *last* logged epoch, and optimizer state is not restored.
            if os.path.exists(log_path):
                log_df = pd.read_csv(log_path)
                if not log_df.empty:
                    start_epoch = int(log_df["epoch"].iloc[-1]) + 1
                    best_loss = float(log_df[config.early_stopping_monitor].min())
                    log_records = log_df.to_dict("records")
            print(f"Model loaded. Resuming from epoch {start_epoch}.")
        else:
            print("No safetensors checkpoint found. Training from scratch.")

    for epoch in range(start_epoch, config.epochs):
        vae.train()
        epoch_losses = {"total_loss": 0.0, "reconstruction_loss": 0.0, "kl_loss": 0.0}
        n_batches = 0

        for data, _ in train_loader:
            data = data.to(device)
            optimizer.zero_grad()
            x_reconstructed, z_mean, z_log_var, _ = vae(data)
            losses = vae.loss_function(data, x_reconstructed, z_mean, z_log_var)
            losses["total_loss"].backward()
            optimizer.step()
            for name, value in losses.items():
                # .get tolerates extra loss terms beyond the three pre-declared ones.
                epoch_losses[name] = epoch_losses.get(name, 0.0) + value.item()
            n_batches += 1

        # Counting batches avoids ZeroDivisionError and works for loaders without __len__.
        if n_batches == 0:
            raise ValueError("train_loader produced no batches; cannot compute epoch losses.")

        avg_losses = {name: total / n_batches for name, total in epoch_losses.items()}
        print(f"--- Epoch {epoch+1} | Avg Loss: {avg_losses['total_loss']:.4f} ---")

        # Rewrite the whole log each epoch so it survives an interrupted run.
        log_records.append({"epoch": epoch, **avg_losses})
        pd.DataFrame(log_records).to_csv(log_path, index=False)

        current_loss = avg_losses[config.early_stopping_monitor]
        if current_loss < best_loss - config.early_stopping_min_delta:
            best_loss = current_loss
            epochs_no_improve = 0
            save_file(vae.state_dict(), best_path)
            print(f"New best loss: {best_loss:.4f}. Checkpoint saved as vae_best.safetensors")
        else:
            epochs_no_improve += 1

        save_every = config.save_every_n_epochs
        if save_every and epoch > 0 and epoch % save_every == 0:
            epoch_filename = f"vae_epoch_{epoch}.safetensors"
            save_file(vae.state_dict(), os.path.join(config.current_checkpoint_dir, epoch_filename))
            print(f"Checkpoint saved for epoch {epoch} as {epoch_filename}")

        if epochs_no_improve >= config.early_stopping_patience:
            print(f"Early stopping triggered after {epoch+1} epochs.")
            break

    save_file(vae.state_dict(), os.path.join(config.current_checkpoint_dir, "vae_final.safetensors"))
    print("Training complete. Final model saved as vae_final.safetensors")


In [15]:
def generate_and_save_latent_space(config: TrainingConfig, vae: VAE, data_loader: DataLoader):
    """Encode every batch of ``data_loader`` and pickle the latent statistics with labels.

    Weights are loaded from ``vae_best.safetensors`` in ``config.current_checkpoint_dir``,
    falling back to ``vae_final.safetensors``. If neither exists, a warning is printed
    and the model's current in-memory weights are used.

    The output file ``<config.base_data_dir>/S1/latent_space_dataset/<config.current_run_name>.pkl``
    holds the list ``[z_data_combined, labels_np]``:
      * ``z_data_combined`` – encoder ``z_mean`` and ``z_log_var`` concatenated along
        axis 1 (presumably shape ``(N, 2 * latent_dim)`` — confirm against the encoder);
      * ``labels_np`` – all label batches concatenated along axis 0.

    Args:
        config: run configuration (device, checkpoint dir, data dir, run name).
        vae: model whose ``encoder(data)`` returns ``(z_mean, z_log_var, _)``.
        data_loader: yields ``(data, labels)`` batches with tensor labels.

    Raises:
        ValueError: if ``data_loader`` yields no batches.
    """
    device = torch.device(config.device)
    vae.to(device)
    vae.eval()

    checkpoint_path = os.path.join(config.current_checkpoint_dir, "vae_best.safetensors")
    if not os.path.exists(checkpoint_path):
        checkpoint_path = os.path.join(config.current_checkpoint_dir, "vae_final.safetensors")

    if os.path.exists(checkpoint_path):
        print(f"Loading model for inference from {checkpoint_path}")
        # load_file returns a state dict; it must be copied into the module explicitly.
        vae.load_state_dict(load_file(checkpoint_path, device=str(device)))
    else:
        print("WARNING: No safetensors model found for inference.")

    all_z_mean, all_z_log_var, all_labels = [], [], []
    with torch.no_grad():
        for data, labels in data_loader:
            data = data.to(device)
            z_mean, z_log_var, _ = vae.encoder(data)
            all_z_mean.append(z_mean.cpu().numpy())
            all_z_log_var.append(z_log_var.cpu().numpy())
            all_labels.append(labels.cpu().numpy())

    # np.concatenate on an empty list gives an opaque error; fail with a clear message.
    if not all_z_mean:
        raise ValueError("data_loader produced no batches; nothing to encode.")

    z_data_combined = np.concatenate([np.concatenate(all_z_mean), np.concatenate(all_z_log_var)], axis=1)
    labels_np = np.concatenate(all_labels)

    latent_space_dir = os.path.join(config.base_data_dir, "S1", "latent_space_dataset")
    os.makedirs(latent_space_dir, exist_ok=True)

    output_filename = os.path.join(latent_space_dir, f"{config.current_run_name}.pkl")
    # Pickle is convenient for the [array, array] pair, but loading a pickle runs
    # arbitrary code: consumers should only open files they produced themselves.
    with open(output_filename, 'wb') as f:
        pickle.dump([z_data_combined, labels_np], f)
    print(f"Latent space data saved to: {output_filename}")

