In [None]:
import torch # we need pytorch installed
from models import VAE # import the VAE model
import numpy as np
import matplotlib.pyplot as plt
import pandas
import json
from scipy.io import savemat

from scipy import stats
import scipy
import datetime
import os
import time
from itertools import product
import random

# Silence UserWarning messages that originate from torch modules.
import warnings
warnings.filterwarnings("ignore", category=UserWarning, module='torch')

# One fixed seed shared by every random number generator used below, so that
# repeated runs of the script start from the same state.
SEED = 42

# Seed Python, NumPy and PyTorch (default generator, current GPU, all GPUs),
# in that order.
for _seed_fn in (
    random.seed,
    np.random.seed,
    torch.manual_seed,
    torch.cuda.manual_seed,
    torch.cuda.manual_seed_all,
):
    _seed_fn(SEED)

# Request deterministic cuDNN kernels and switch off the cuDNN autotuner.
torch.backends.cudnn.deterministic = True
torch.backends.cudnn.benchmark = False

# Fixed cuBLAS workspace size; PyTorch's reproducibility notes call for this
# setting when running on CUDA 10.2 or newer.
os.environ["CUBLAS_WORKSPACE_CONFIG"] = ":4096:8"


#%% SETTINGS
# Sweep grid: each latent size is combined with each hidden-layer layout and
# each beta value.
DataType = 'cn'
LatentDims = [6, 7, 8]
HiddenDimsList = [
    np.array(layer_widths)
    for layer_widths in (
        [64, 48, 32, 16],
        [64, 32, 16],
        [32, 16, 8],
        [48, 32, 16],
    )
]
Betas = [0.0013, 0.0014, 0.0015, 0.0016, 0.0017, 0.0018]

# Training hyper-parameters shared by every configuration.
BatchSize = 128
Epochs = 5000

# Label used as the prefix of every output file name.
# Other tags used so far: UOP_near_crash_steeper  UOP_near_crash  UOP_inc_lit_disps
tag = 'UOP_uniform_pGRAM'

# Training data (JSON).
json_file = '/Users/gracecalkins/Local_Documents/local_code/pipag_training/data/UOP_uniform_pGRAM_2000_data_energy_scaled_downsampled_.json'
# json_file = f'/Users/gracecalkins/Local_Documents/local_code/pipag_training/data/UOP_near_crash_steeper_near_escape_COMBINED_5000_data_energy_scaled_downsampled_.json'

# Timestamp (month, day, hour, minute, second) appended to output file names.
date = f'{datetime.datetime.now():%m%d%H%M%S}'

# Folder that receives all figures; created if it does not exist yet.
figPath = './VAE_arch_eval_round2'
os.makedirs(figPath, exist_ok=True)


# Run on the GPU when CUDA is available, otherwise on the CPU.
Dev = torch.device('cuda' if torch.cuda.is_available() else 'cpu')

# Read the whole dataset; it is a dict keyed by sample name (sample0, sample1, ...).
with open(json_file, 'r') as json_handle:
    dataset = json.load(json_handle)

# Keep only the energy arrays of entries whose label is 0 (capture scenarios).
filtered_data = []
for entry in dataset.values():
    if entry['label'] == 0:
        filtered_data.append(entry['energy'])

# Stack the selected samples into one array: rows are samples, columns are features.
data = np.array(filtered_data)
n_data, k_data = data.shape

# Float32 training tensor on the chosen device, plus a CPU-side handle for plotting.
x_trn = torch.from_numpy(data).float().to(Dev)
x_trn_cpu = x_trn.cpu()

print(f'Loaded {len(filtered_data)} samples with label 0.')
print('total loss function calls: ', n_data / BatchSize * Epochs)


# Train one VAE per (latent dim, hidden layout, beta) combination, save three
# diagnostic figures for each, and print a running estimate of the time left.
numCases = len(LatentDims) * len(HiddenDimsList) * len(Betas)
print('Number of cases: ', numCases)

sum_time = 0   # accumulated wall-clock seconds over finished cases
complete = 0   # number of finished cases

# product() walks the grid in the same order as three nested loops would:
# betas vary fastest, latent dimensions slowest.
for LatentDim, HiddenDimsArr, Beta in product(LatentDims, HiddenDimsList, Betas):
    tic = time.time()
    HiddenDims = HiddenDimsArr.tolist()
    print(f"LD: {LatentDim}, HD: {HiddenDims}, Beta: {Beta}")

    #%% file-name suffix identifying the current configuration
    # Hidden layer widths are joined with 'x', e.g. 64x48x32x16.
    layers = 'x'.join(str(width) for width in HiddenDims)
    suffix = (f'{tag}_{DataType}_dim{LatentDim}_{layers}'
              f'_beta{Beta}_batch{BatchSize}_epochs{Epochs}_{date}')

    #%% train
    vae_model = VAE(k_data,
                    latent_dim=LatentDim,    # latent dimension of the auto-encoder
                    hidden_dims=HiddenDims,  # structure of the encoder and decoder
                    beta=Beta,
                    dev=Dev)
    vae_model.to(Dev)
    # train() returns three loss histories: total, reconstruction and KLD.
    epoch_loss, rec_loss, kld_loss = vae_model.train(x_trn,
                                                     batch_size=BatchSize,
                                                     epochs=Epochs)

    #%% plots
    # total loss
    fig, ax = plt.subplots()
    ax.grid()
    ax.plot(epoch_loss, '.')
    ax.set_ylim(-1, 30)
    fig.savefig(os.path.join(figPath, f'loss_{suffix}.png'))

    # loss sources
    fig, ax = plt.subplots()
    ax.grid()
    ax.plot(rec_loss, '.', label='reconstruction')
    ax.plot(kld_loss, '.', label='kld')
    ax.legend()
    fig.savefig(os.path.join(figPath, f'lossTerms_{suffix}.png'))

    # compare the per-feature variance of the training data with that of an
    # equally sized batch of samples drawn from the trained model
    big_samp = vae_model.sample(num_samples=n_data).cpu()
    fig, ax = plt.subplots()
    ax.grid()
    ax.plot(stats.describe(x_trn_cpu, axis=0).variance, label='Training Variance')
    ax.plot(stats.describe(big_samp.detach(), axis=0).variance, label='VAE Variance')
    ax.legend()
    fig.savefig(os.path.join(figPath, f'var_{suffix}.png'))

    # Release all figures of this case so they do not pile up over the sweep.
    plt.close('all')

    #%% timing
    toc = time.time()
    run_time = toc - tic
    sum_time += run_time
    complete += 1
    avg_time = sum_time / complete
    # Count the case that just finished as done before estimating, so the
    # estimate covers only the cases still to run and reaches 0 at the end.
    time_left = (numCases - complete) * avg_time
    print(f"Estimated time left: {time_left/60:.2f} minutes")

Loaded 1547 samples with label 0.
total loss function calls:  60429.6875
Number of cases:  72
LD: 6, HD: [64, 48, 32, 16], Beta: 0.0013


  8%|▊         | 384/5000 [00:09<01:52, 41.10 Epoch/s, loss=0.0217]