In [9]:
import torch # we need pytorch installed
from models import VAE # import the VAE model
import numpy as np
import matplotlib.pyplot as plt
import pandas
import json
from scipy.io import savemat

from scipy import stats
import scipy
import datetime
import os
import time
from itertools import product
import random

# Suppress warnings
import warnings
warnings.filterwarnings("ignore", category=UserWarning, module='torch')

# Make determinsitic
SEED = 42  # or any fixed number

# Python, NumPy, PyTorch seed
random.seed(SEED)
np.random.seed(SEED)
torch.manual_seed(SEED)
torch.cuda.manual_seed(SEED)
torch.cuda.manual_seed_all(SEED)

# For deterministic behavior
torch.backends.cudnn.deterministic = True
torch.backends.cudnn.benchmark = False

# Optional: ensure reproducibility across runs
os.environ["CUBLAS_WORKSPACE_CONFIG"] = ":4096:8"  # Needed for CUDA <11.2


#%% SETTINGS
DataType = 'cn'
LatentDims = [6,7,8]
HiddenDimsList = [np.array([64,48,32,16]),
                  np.array([64,32,16]),
                  np.array([32,16,8]),
                  np.array([48,32,16]),]
Betas = [0.0013, 0.0014, 0.0015, 0.0016, 0.0017, 0.0018]
BatchSize = 128
Epochs = 3000
tag = 'UOP_combined' # UOP_near_crash_steeper  UOP_near_crash  UOP_inc_lit_disps UOP_combined  UOP_uniform_pGRAM

# load training data
# json_file = f'/Users/gracecalkins/Local_Documents/local_code/pipag_training/data/UOP_uniform_pGRAM_2000_data_energy_scaled_downsampled_.json'
json_file = f'/Users/gracecalkins/Local_Documents/local_code/pipag_training/data/UOP_near_crash_steeper_near_escape_COMBINED_5000_data_energy_scaled_downsampled_.json'

date = datetime.datetime.now().strftime('%m%d%H%M%S')


# Create parent 'figs' folder if it doesn't exist
figPath = './VAE_arch_eval_combined'
os.makedirs(figPath, exist_ok=True)


# set device
Dev = torch.device('cuda') if torch.cuda.is_available() else torch.device('cpu')

# Load the JSON file
with open(json_file, 'r') as f:
    dataset = json.load(f)

# dataset is a dict: sample0, sample1, etc.
# Collect energy arrays where label == 0 (capture scenarios)
filtered_data = [entry['energy'] for entry in dataset.values() if entry['label'] == 0]
# Downselect to 1500 samples
filtered_data = filtered_data[:1500]

# Convert to numpy array
data = np.array(filtered_data)

n_data, k_data = data.shape # define number of samples and features

x_trn = torch.from_numpy(data).type(torch.FloatTensor).to(Dev)
x_trn_cpu = x_trn.cpu()

print('Loaded {} samples with label 0.'.format(len(filtered_data)))

# Loop over latent dimensions, hidden dimensions, and betas
numCases = len(LatentDims) * len(HiddenDimsList) * len(Betas)
print('Number of cases: ', numCases)

sum_time = 0
complete = 0

# Loop over all combinations of latent dimensions, hidden dimensions, and betas
for LatentDim in LatentDims:
    for HiddenDimsArr in HiddenDimsList:
        for Beta in Betas:
            tic = time.time()
            HiddenDims = HiddenDimsArr.tolist()
            print(f"LD: {LatentDim}, HD: {HiddenDims}, Beta: {Beta}")

            #%% create directory for saving figures
            # Create a directory for the current configuration
            suffix = f'{tag}_{DataType}_dim{LatentDim}_'
            for i, d in enumerate(HiddenDims):
                suffix += str(HiddenDims[i])
                if i < len(HiddenDims) - 1:
                    suffix += 'x'
            suffix += f'_beta{Beta}_batch{BatchSize}_epochs{Epochs}_{date}'


            #%% train
            vae_model = VAE(k_data,
                            latent_dim = LatentDim, # define the latent dimension of the auto-encoder
                            hidden_dims = HiddenDims, # define the strucute of the encoder and decoder
                            beta = Beta,
                            dev = Dev
                        )
            vae_model.to(Dev)
            epoch_loss, rec_loss, kld_loss = vae_model.train(x_trn,
                                                        batch_size = BatchSize,
                                                        epochs = Epochs
                                                        ) # train the vae model using synthetic data
            

            #%% plots
            # total loss
            fig, ax = plt.subplots()
            ax.grid()
            ax.plot(epoch_loss, '.')
            ax.set_ylim(-1, 30)
            plt.savefig(os.path.join(figPath, f'loss_{suffix}.png'))

            # loss sources
            fig, ax = plt.subplots()
            ax.grid()
            ax.plot(rec_loss, '.', label = 'reconstruction')
            ax.plot(kld_loss, '.', label = 'kld')
            ax.legend()
            plt.savefig(os.path.join(figPath, f'lossTerms_{suffix}.png'))

            # compare variance
            big_samp = vae_model.sample(num_samples = n_data).cpu()
            fig, ax = plt.subplots()
            ax.grid()
            ax.plot(stats.describe(x_trn_cpu, axis = 0)[3], label = 'Training Variance')
            ax.plot(stats.describe(big_samp.detach(), axis = 0)[3], label = 'VAE Variance')
            ax.legend()
            plt.savefig(os.path.join(figPath, f'var_{suffix}.png'))
            plt.close('all')

            toc = time.time()
            run_time = toc - tic
            avg_time = (sum_time+run_time)/(complete+1)
            time_left = (numCases-complete)*avg_time
            sum_time += run_time
            complete += 1
            print(f"Estimated time left: {time_left/60:.2f} minutes")

Loaded 1500 samples with label 0.
Number of cases:  72
LD: 6, HD: [64, 48, 32, 16], Beta: 0.0013


100%|██████████| 3000/3000 [01:14<00:00, 40.19 Epoch/s, loss=0.0147]


Estimated time left: 89.78 minutes
LD: 6, HD: [64, 48, 32, 16], Beta: 0.0014


100%|██████████| 3000/3000 [01:16<00:00, 39.44 Epoch/s, loss=0.0153]


Estimated time left: 89.36 minutes
LD: 6, HD: [64, 48, 32, 16], Beta: 0.0015


100%|██████████| 3000/3000 [01:17<00:00, 38.60 Epoch/s, loss=0.0158]


Estimated time left: 89.02 minutes
LD: 6, HD: [64, 48, 32, 16], Beta: 0.0016


100%|██████████| 3000/3000 [01:14<00:00, 40.49 Epoch/s, loss=0.0164]


Estimated time left: 87.15 minutes
LD: 6, HD: [64, 48, 32, 16], Beta: 0.0017


100%|██████████| 3000/3000 [01:11<00:00, 41.73 Epoch/s, loss=0.0175]


Estimated time left: 85.03 minutes
LD: 6, HD: [64, 48, 32, 16], Beta: 0.0018


100%|██████████| 3000/3000 [01:09<00:00, 43.03 Epoch/s, loss=0.0183]


Estimated time left: 82.93 minutes
LD: 6, HD: [64, 32, 16], Beta: 0.0013


100%|██████████| 3000/3000 [00:56<00:00, 53.44 Epoch/s, loss=0.0145]


Estimated time left: 78.86 minutes
LD: 6, HD: [64, 32, 16], Beta: 0.0014


100%|██████████| 3000/3000 [00:58<00:00, 51.52 Epoch/s, loss=0.0151]


Estimated time left: 75.86 minutes
LD: 6, HD: [64, 32, 16], Beta: 0.0015


100%|██████████| 3000/3000 [00:56<00:00, 53.39 Epoch/s, loss=0.0158]


Estimated time left: 73.07 minutes
LD: 6, HD: [64, 32, 16], Beta: 0.0016


100%|██████████| 3000/3000 [00:57<00:00, 52.20 Epoch/s, loss=0.017] 


Estimated time left: 70.78 minutes
LD: 6, HD: [64, 32, 16], Beta: 0.0017


100%|██████████| 3000/3000 [00:57<00:00, 52.15 Epoch/s, loss=0.0176]


Estimated time left: 68.74 minutes
LD: 6, HD: [64, 32, 16], Beta: 0.0018


100%|██████████| 3000/3000 [00:57<00:00, 51.94 Epoch/s, loss=0.0179]


Estimated time left: 66.90 minutes
LD: 6, HD: [32, 16, 8], Beta: 0.0013


100%|██████████| 3000/3000 [00:45<00:00, 65.92 Epoch/s, loss=0.0148]


Estimated time left: 64.25 minutes
LD: 6, HD: [32, 16, 8], Beta: 0.0014


100%|██████████| 3000/3000 [00:44<00:00, 67.50 Epoch/s, loss=0.0155]


Estimated time left: 61.80 minutes
LD: 6, HD: [32, 16, 8], Beta: 0.0015


100%|██████████| 3000/3000 [00:45<00:00, 66.41 Epoch/s, loss=0.0158]


Estimated time left: 59.62 minutes
LD: 6, HD: [32, 16, 8], Beta: 0.0016


100%|██████████| 3000/3000 [00:44<00:00, 67.25 Epoch/s, loss=0.017] 


Estimated time left: 57.59 minutes
LD: 6, HD: [32, 16, 8], Beta: 0.0017


100%|██████████| 3000/3000 [00:44<00:00, 67.18 Epoch/s, loss=0.0176]


Estimated time left: 55.71 minutes
LD: 6, HD: [32, 16, 8], Beta: 0.0018


100%|██████████| 3000/3000 [00:45<00:00, 65.30 Epoch/s, loss=0.0183]


Estimated time left: 54.02 minutes
LD: 6, HD: [48, 32, 16], Beta: 0.0013


100%|██████████| 3000/3000 [01:03<00:00, 47.55 Epoch/s, loss=0.0145]


Estimated time left: 53.26 minutes
LD: 6, HD: [48, 32, 16], Beta: 0.0014


100%|██████████| 3000/3000 [01:02<00:00, 47.80 Epoch/s, loss=0.015] 


Estimated time left: 52.44 minutes
LD: 6, HD: [48, 32, 16], Beta: 0.0015


100%|██████████| 3000/3000 [00:56<00:00, 53.17 Epoch/s, loss=0.0158]


Estimated time left: 51.33 minutes
LD: 6, HD: [48, 32, 16], Beta: 0.0016


100%|██████████| 3000/3000 [00:54<00:00, 54.69 Epoch/s, loss=0.0167]


Estimated time left: 50.18 minutes
LD: 6, HD: [48, 32, 16], Beta: 0.0017


100%|██████████| 3000/3000 [00:54<00:00, 54.77 Epoch/s, loss=0.0171]


Estimated time left: 49.05 minutes
LD: 6, HD: [48, 32, 16], Beta: 0.0018


100%|██████████| 3000/3000 [00:57<00:00, 52.60 Epoch/s, loss=0.0178]


Estimated time left: 48.01 minutes
LD: 7, HD: [64, 48, 32, 16], Beta: 0.0013


100%|██████████| 3000/3000 [01:09<00:00, 43.28 Epoch/s, loss=0.0144]


Estimated time left: 47.37 minutes
LD: 7, HD: [64, 48, 32, 16], Beta: 0.0014


100%|██████████| 3000/3000 [01:10<00:00, 42.35 Epoch/s, loss=0.0148]


Estimated time left: 46.74 minutes
LD: 7, HD: [64, 48, 32, 16], Beta: 0.0015


100%|██████████| 3000/3000 [01:09<00:00, 43.41 Epoch/s, loss=0.0158]


Estimated time left: 46.01 minutes
LD: 7, HD: [64, 48, 32, 16], Beta: 0.0016


100%|██████████| 3000/3000 [01:06<00:00, 45.00 Epoch/s, loss=0.0164]


Estimated time left: 45.19 minutes
LD: 7, HD: [64, 48, 32, 16], Beta: 0.0017


100%|██████████| 3000/3000 [01:06<00:00, 45.04 Epoch/s, loss=0.0173]


Estimated time left: 44.35 minutes
LD: 7, HD: [64, 48, 32, 16], Beta: 0.0018


100%|██████████| 3000/3000 [01:16<00:00, 39.40 Epoch/s, loss=0.0176]


Estimated time left: 43.72 minutes
LD: 7, HD: [64, 32, 16], Beta: 0.0013


100%|██████████| 3000/3000 [00:57<00:00, 52.56 Epoch/s, loss=0.0145]


Estimated time left: 42.62 minutes
LD: 7, HD: [64, 32, 16], Beta: 0.0014


100%|██████████| 3000/3000 [00:57<00:00, 52.39 Epoch/s, loss=0.0148]


Estimated time left: 41.53 minutes
LD: 7, HD: [64, 32, 16], Beta: 0.0015


100%|██████████| 3000/3000 [00:55<00:00, 54.50 Epoch/s, loss=0.016] 


Estimated time left: 40.41 minutes
LD: 7, HD: [64, 32, 16], Beta: 0.0016


100%|██████████| 3000/3000 [00:54<00:00, 54.55 Epoch/s, loss=0.0163]


Estimated time left: 39.30 minutes
LD: 7, HD: [64, 32, 16], Beta: 0.0017


100%|██████████| 3000/3000 [00:57<00:00, 52.49 Epoch/s, loss=0.0171]


Estimated time left: 38.23 minutes
LD: 7, HD: [64, 32, 16], Beta: 0.0018


100%|██████████| 3000/3000 [00:56<00:00, 52.76 Epoch/s, loss=0.0178]


Estimated time left: 37.17 minutes
LD: 7, HD: [32, 16, 8], Beta: 0.0013


100%|██████████| 3000/3000 [00:44<00:00, 67.41 Epoch/s, loss=0.0148]


Estimated time left: 35.91 minutes
LD: 7, HD: [32, 16, 8], Beta: 0.0014


100%|██████████| 3000/3000 [00:42<00:00, 70.45 Epoch/s, loss=0.0156]


Estimated time left: 34.65 minutes
LD: 7, HD: [32, 16, 8], Beta: 0.0015


100%|██████████| 3000/3000 [00:43<00:00, 69.06 Epoch/s, loss=0.0158]


Estimated time left: 33.43 minutes
LD: 7, HD: [32, 16, 8], Beta: 0.0016


100%|██████████| 3000/3000 [00:43<00:00, 68.72 Epoch/s, loss=0.0166]


Estimated time left: 32.24 minutes
LD: 7, HD: [32, 16, 8], Beta: 0.0017


100%|██████████| 3000/3000 [00:43<00:00, 68.34 Epoch/s, loss=0.0176]


Estimated time left: 31.07 minutes
LD: 7, HD: [32, 16, 8], Beta: 0.0018


100%|██████████| 3000/3000 [00:44<00:00, 67.78 Epoch/s, loss=0.0178]


Estimated time left: 29.93 minutes
LD: 7, HD: [48, 32, 16], Beta: 0.0013


100%|██████████| 3000/3000 [00:55<00:00, 54.21 Epoch/s, loss=0.0144]


Estimated time left: 28.93 minutes
LD: 7, HD: [48, 32, 16], Beta: 0.0014


100%|██████████| 3000/3000 [00:53<00:00, 55.84 Epoch/s, loss=0.0152]


Estimated time left: 27.93 minutes
LD: 7, HD: [48, 32, 16], Beta: 0.0015


100%|██████████| 3000/3000 [00:53<00:00, 55.62 Epoch/s, loss=0.0156]


Estimated time left: 26.92 minutes
LD: 7, HD: [48, 32, 16], Beta: 0.0016


100%|██████████| 3000/3000 [00:56<00:00, 53.42 Epoch/s, loss=0.0165]


Estimated time left: 25.95 minutes
LD: 7, HD: [48, 32, 16], Beta: 0.0017


100%|██████████| 3000/3000 [00:56<00:00, 53.41 Epoch/s, loss=0.0171]


Estimated time left: 24.97 minutes
LD: 7, HD: [48, 32, 16], Beta: 0.0018


100%|██████████| 3000/3000 [00:57<00:00, 52.53 Epoch/s, loss=0.0177]


Estimated time left: 24.01 minutes
LD: 8, HD: [64, 48, 32, 16], Beta: 0.0013


100%|██████████| 3000/3000 [01:11<00:00, 42.15 Epoch/s, loss=0.0148]


Estimated time left: 23.16 minutes
LD: 8, HD: [64, 48, 32, 16], Beta: 0.0014


100%|██████████| 3000/3000 [01:11<00:00, 42.17 Epoch/s, loss=0.015] 


Estimated time left: 22.30 minutes
LD: 8, HD: [64, 48, 32, 16], Beta: 0.0015


100%|██████████| 3000/3000 [01:09<00:00, 43.22 Epoch/s, loss=0.0156]


Estimated time left: 21.41 minutes
LD: 8, HD: [64, 48, 32, 16], Beta: 0.0016


100%|██████████| 3000/3000 [01:07<00:00, 44.26 Epoch/s, loss=0.0162]


Estimated time left: 20.50 minutes
LD: 8, HD: [64, 48, 32, 16], Beta: 0.0017


100%|██████████| 3000/3000 [01:08<00:00, 43.87 Epoch/s, loss=0.0171]


Estimated time left: 19.59 minutes
LD: 8, HD: [64, 48, 32, 16], Beta: 0.0018


100%|██████████| 3000/3000 [01:08<00:00, 43.88 Epoch/s, loss=0.0181]


Estimated time left: 18.67 minutes
LD: 8, HD: [64, 32, 16], Beta: 0.0013


100%|██████████| 3000/3000 [00:55<00:00, 53.94 Epoch/s, loss=0.0145]


Estimated time left: 17.67 minutes
LD: 8, HD: [64, 32, 16], Beta: 0.0014


100%|██████████| 3000/3000 [00:56<00:00, 53.44 Epoch/s, loss=0.0151]


Estimated time left: 16.67 minutes
LD: 8, HD: [64, 32, 16], Beta: 0.0015


100%|██████████| 3000/3000 [00:57<00:00, 51.77 Epoch/s, loss=0.0157]


Estimated time left: 15.69 minutes
LD: 8, HD: [64, 32, 16], Beta: 0.0016


100%|██████████| 3000/3000 [00:57<00:00, 51.82 Epoch/s, loss=0.0164]


Estimated time left: 14.70 minutes
LD: 8, HD: [64, 32, 16], Beta: 0.0017


100%|██████████| 3000/3000 [00:57<00:00, 52.21 Epoch/s, loss=0.0172]


Estimated time left: 13.72 minutes
LD: 8, HD: [64, 32, 16], Beta: 0.0018


100%|██████████| 3000/3000 [00:56<00:00, 53.00 Epoch/s, loss=0.0175]


Estimated time left: 12.73 minutes
LD: 8, HD: [32, 16, 8], Beta: 0.0013


100%|██████████| 3000/3000 [00:45<00:00, 65.45 Epoch/s, loss=0.0143]


Estimated time left: 11.71 minutes
LD: 8, HD: [32, 16, 8], Beta: 0.0014


100%|██████████| 3000/3000 [00:45<00:00, 66.15 Epoch/s, loss=0.0153]


Estimated time left: 10.70 minutes
LD: 8, HD: [32, 16, 8], Beta: 0.0015


100%|██████████| 3000/3000 [00:45<00:00, 66.35 Epoch/s, loss=0.0159]


Estimated time left: 9.69 minutes
LD: 8, HD: [32, 16, 8], Beta: 0.0016


100%|██████████| 3000/3000 [00:44<00:00, 67.55 Epoch/s, loss=0.0165]


Estimated time left: 8.69 minutes
LD: 8, HD: [32, 16, 8], Beta: 0.0017


100%|██████████| 3000/3000 [00:44<00:00, 67.48 Epoch/s, loss=0.0171]


Estimated time left: 7.70 minutes
LD: 8, HD: [32, 16, 8], Beta: 0.0018


100%|██████████| 3000/3000 [00:44<00:00, 66.90 Epoch/s, loss=0.0176]


Estimated time left: 6.71 minutes
LD: 8, HD: [48, 32, 16], Beta: 0.0013


100%|██████████| 3000/3000 [00:54<00:00, 54.72 Epoch/s, loss=0.0144]


Estimated time left: 5.75 minutes
LD: 8, HD: [48, 32, 16], Beta: 0.0014


100%|██████████| 3000/3000 [00:54<00:00, 54.84 Epoch/s, loss=0.0151]


Estimated time left: 4.79 minutes
LD: 8, HD: [48, 32, 16], Beta: 0.0015


100%|██████████| 3000/3000 [00:55<00:00, 54.29 Epoch/s, loss=0.0159]


Estimated time left: 3.83 minutes
LD: 8, HD: [48, 32, 16], Beta: 0.0016


100%|██████████| 3000/3000 [00:55<00:00, 54.52 Epoch/s, loss=0.0166]


Estimated time left: 2.87 minutes
LD: 8, HD: [48, 32, 16], Beta: 0.0017


100%|██████████| 3000/3000 [00:55<00:00, 53.87 Epoch/s, loss=0.0168]


Estimated time left: 1.91 minutes
LD: 8, HD: [48, 32, 16], Beta: 0.0018


100%|██████████| 3000/3000 [00:55<00:00, 53.94 Epoch/s, loss=0.0173]


Estimated time left: 0.96 minutes
