# Imports

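To run (Google Colab):
- Select **Runtime > Change runtime type > T4 GPU** so that CUDA is available.
- Install `pytorch_lightning` (next cell).
- Clone the `pipag_training` repo so the shared helpers in `gmvae_common` can be imported.
- Mount Google Drive: the data files are read from, and all results are written to, `MyDrive/JP_gmvae_data`.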

In [3]:
!pip install pytorch_lightning

Collecting pytorch_lightning
  Downloading pytorch_lightning-2.5.1.post0-py3-none-any.whl.metadata (20 kB)
Collecting torchmetrics>=0.7.0 (from pytorch_lightning)
  Downloading torchmetrics-1.7.2-py3-none-any.whl.metadata (21 kB)
Collecting lightning-utilities>=0.10.0 (from pytorch_lightning)
  Downloading lightning_utilities-0.14.3-py3-none-any.whl.metadata (5.6 kB)
Collecting nvidia-cuda-nvrtc-cu12==12.4.127 (from torch>=2.1.0->pytorch_lightning)
  Downloading nvidia_cuda_nvrtc_cu12-12.4.127-py3-none-manylinux2014_x86_64.whl.metadata (1.5 kB)
Collecting nvidia-cuda-runtime-cu12==12.4.127 (from torch>=2.1.0->pytorch_lightning)
  Downloading nvidia_cuda_runtime_cu12-12.4.127-py3-none-manylinux2014_x86_64.whl.metadata (1.5 kB)
Collecting nvidia-cuda-cupti-cu12==12.4.127 (from torch>=2.1.0->pytorch_lightning)
  Downloading nvidia_cuda_cupti_cu12-12.4.127-py3-none-manylinux2014_x86_64.whl.metadata (1.6 kB)
Collecting nvidia-cudnn-cu12==9.1.0.70 (from torch>=2.1.0->pytorch_lightning)
  Dow

In [4]:
import os
import random
import numpy as np
import torch
from pytorch_lightning import seed_everything

SEED = 42
seed_everything(SEED, workers=True)
torch.backends.cudnn.deterministic = True
torch.backends.cudnn.benchmark = False
os.environ["CUBLAS_WORKSPACE_CONFIG"] = ":4096:8"

import torch.optim as optim
from torch.distributions import Normal
import matplotlib.patches as mpatches
import matplotlib.pyplot as plt
import pathlib
from torch.utils.data import DataLoader
from pytorch_lightning import LightningDataModule
import seaborn as sns
import json
from tqdm.notebook import tqdm_notebook
from time import time
import datetime
import pickle
import matplotlib.lines as mlines
import sys
from sklearn.decomposition import PCA

INFO:lightning_fabric.utilities.seed:Seed set to 42


In [5]:
# Clone (or update) the shared-code repo and make `gmvae_common` importable.
# SECURITY: never hardcode a GitHub token in the notebook. The personal access token that was
# previously embedded in this cell's clone URL has been published with the notebook and must be
# revoked on GitHub. The token is now read at run time from $GITHUB_TOKEN or an interactive prompt.
# It ends up only in the cloned repo's .git/config on the (ephemeral) Colab VM, where `git pull` uses it.
import getpass
import os
import subprocess
import sys

REPO_DIR = '/content/pipag_training'
if not os.path.isdir(REPO_DIR):
    gh_token = os.environ.get('GITHUB_TOKEN') or getpass.getpass('GitHub token for gcalkins64/pipag_training: ')
    clone = subprocess.run(
        ['git', 'clone', f'https://gcalkins64:{gh_token}@github.com/gcalkins64/pipag_training.git', REPO_DIR],
        check=False,
    )
    del gh_token
    # Raise without echoing the command line, which contains the token.
    if clone.returncode != 0:
        raise RuntimeError(f'git clone failed with exit code {clone.returncode}')
subprocess.run(['git', '-C', REPO_DIR, 'pull'], check=True)
if REPO_DIR not in sys.path:  # avoid duplicate entries when the cell is re-run
    sys.path.append(REPO_DIR)

Cloning into 'pipag_training'...
remote: Enumerating objects: 291, done.[K
remote: Counting objects: 100% (107/107), done.[K
remote: Compressing objects: 100% (77/77), done.[K
remote: Total 291 (delta 59), reused 59 (delta 30), pack-reused 184 (from 1)[K
Receiving objects: 100% (291/291), 7.16 MiB | 8.28 MiB/s, done.
Resolving deltas: 100% (148/148), done.
Already up to date.


In [6]:
# Mount Google Drive at /content/drive; the data files and all run outputs live under
# drive/MyDrive/JP_gmvae_data (relative to the /content working directory).
from google.colab import drive
drive.mount(mountpoint='/content/drive')

Mounted at /content/drive


In [7]:
# Shared GMVAE helpers from the cloned repo. Imported by name (instead of `import *`) so it is
# clear where each helper comes from and nothing imported above is silently shadowed.
from gmvae_common import (
    AerocaptureDataModuleCUDA,
    Decoder,
    Encoder,
    decoder_step,
    em_step,
    encoder_step,
    plot_latent_space_with_clusters,
    seabornSettings,
)

# Check Devices

In [8]:
# Use the GPU when one is available, otherwise fall back to the CPU.
use_cuda = torch.cuda.is_available()
device = torch.device("cuda") if use_cuda else torch.device("cpu")

# Make newly created float tensors default to that device.
# NOTE(review): torch.set_default_tensor_type is deprecated in recent PyTorch; if migrating to
# torch.set_default_device / torch.set_default_dtype, re-check the data module and DataLoaders.
default_tensor_type = torch.cuda.FloatTensor if use_cuda else torch.FloatTensor
torch.set_default_tensor_type(default_tensor_type)
print(device)

cuda


  _C._set_default_tensor_type(t)


# Settings

In [9]:
# ---- Data split sizes ----
n_train = 1024
n_val = 128
n_test = 128
# For testing small batches of data
# n_train = 400
# n_val = 50
# n_test = 50

# ---- Encoder hidden-layer widths (the decoder is built with them in reverse order) ----
# Two layers (current). For three layers use e.g. hd1, hd2, hd3 = 48, 32, 16.
hd1 = 32
hd2 = 16
hd3 = None  # None -> two hidden layers
# Derived from hd1..hd3 so this list always matches the hd* values encoded in the run postfix.
hidden_dims = [width for width in (hd1, hd2, hd3) if width is not None]

# ---- Sweep over latent dimensionality and number of GMM clusters ----
latent_dims = [4, 5, 6]
n_clustersS = [2, 3, 4, 5, 6]
# [latent_dim, n_clusters] pairs to skip, e.g. [[4, 2], [4, 3], [4, 4], [4, 5]]
skip_combos = []

# ---- Optimisation ----
lr = 1e-3
n_epochs = 10_000  # 30_000
batch_size = 128
em_reg = 1e-6  # passed to em_step (gmvae_common)
decoder_var = 1e-5  # passed to Decoder(decoder_var=...)

# ---- Output ----
plot_interval = 500  # epochs between diagnostic figures
dpi = 300

D = 1  # num_modalities
downsampleNum = 64  # presumably points per downsampled trajectory; NOTE(review): not referenced below

loadFlag = False  # True to load in old case, False to run a new case

# ---- Dataset (uncomment exactly one data/inds/tag triple) ----
# data = '1_near_escape_fnpag_2000_data_energy_scaled_downsampled_'
# inds = '1_near_escape_fnpag_1999_inds_energy_scaled_downsampled_'
# tag = 'near_escape'

# data = '1_near_crash_fnpag_2000_data_energy_scaled_downsampled_'
# inds = '1_near_crash_fnpag_2000_inds_energy_scaled_downsampled_'
# tag = 'near_crash'

data = 'UOP_inc_lit_disps_5000_data_energy_scaled_downsampled_'
inds = 'UOP_inc_lit_disps_4999_inds_energy_scaled_downsampled_'
tag = 'near_escape_OLD'

# data = 'UOP_near_crash_steeper_5000_data_energy_scaled_downsampled_'
# inds = 'UOP_near_crash_steeper_4997_inds_energy_scaled_downsampled_'
# tag = 'near_crash_OLD'

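# GMVAE training sweep

Trains one GMVAE for every `(latent_dim, n_clusters)` pair in the settings (except those in `skip_combos`). Figures, the best checkpoint and the loss histories of each run are written to a folder under `MyDrive/JP_gmvae_data`.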

In [None]:
for latent_dim in latent_dims:
   for n_clusters in n_clustersS:
      if [latent_dim, n_clusters] in skip_combos:
        continue
      else:
          K = n_clusters
          Z = latent_dim

          # Run timestamp. Imported under an alias so the `datetime` module name from the
          # imports cell is not rebound to the datetime class.
          import datetime as dt
          timestr = dt.datetime.now().strftime("%Y%m%d_%H%M%S")

          if loadFlag:
              # NOTE(review): the folder and postfix of the run to load are fixed, so every
              # (latent_dim, n_clusters) combination of the sweep points at this same run.
              dirname = os.path.join("drive", "MyDrive", "JP_gmvae_data", 'gmvae_em_aerocapture_energy_20250429_155508_5_4')
              postfix = '_42_1024_128_128_32_16_5_4_0.001000_128_0.001000_0.000010_1000.000000_'
          else:
              dirname = os.path.join("drive", "MyDrive", "JP_gmvae_data", "gmvae_"+tag+"_"+timestr+"_L"+str(latent_dim)+"_C"+str(n_clusters))
              # postfix encodes the run configuration: the integer settings joined by '_', then
              # lr, batch size, em_reg * 1e3, decoder_var and n_epochs (floats as '{:f}').
              layer_sizes = [hd1, hd2] + ([hd3] if hd3 is not None else [])
              int_fields = [SEED, n_train, n_val, n_test, *layer_sizes, latent_dim, n_clusters]
              postfix = ('_' + '_'.join('{:d}'.format(value) for value in int_fields)
                         + '_{:f}_{:d}_{:f}_{:f}_{:f}_'.format(lr, batch_size, em_reg * 1e3, decoder_var, n_epochs))
          os.makedirs(dirname, exist_ok=True)
          print("Filepath directory: " + dirname)
          print("Filepath postfix: " + postfix)

          # ---- Data: fixed sample indices -> train/val/test loaders ----
          data_dir = os.path.join("drive", "MyDrive", "JP_gmvae_data", f"{data}.json")

          with open(os.path.join("drive", "MyDrive", "JP_gmvae_data", f"{inds}.json"), 'r') as f:
              sample_list_load = json.load(f)
          sample_list = [int(sample) for sample in sample_list_load['sample_list']]
          # Summarise rather than dumping the full index list into the cell output.
          print(f"Loaded {len(sample_list)} sample indices from {inds}.json (first 10: {sample_list[:10]})")

          data_module = AerocaptureDataModuleCUDA(data_dir=data_dir, n_train=n_train, n_val=n_val, n_test=n_test,
                                                  train_batch=batch_size, val_batch=batch_size, test_batch=batch_size,
                                                  num_workers=0)
          data_module.setup("fit", sample_list=sample_list)

          train_loader = data_module.train_dataloader()
          val_loader = data_module.val_dataloader()
          test_loader = data_module.test_dataloader()
          data_dim = len(train_loader.dataset[0][0])  # length of one input trajectory
          # NOTE(review): this mapping (label 0 -> 'capture'/C2, 1 -> 'escape'/C3) disagrees with the
          # hard-coded legend below (C2 = 'Escape', C3 = 'Capture'); confirm which one is correct.
          text_labels = ['capture', 'escape', 'impact']
          label_colors = ['C2', 'C3', 'C4']

          num_train_batches = len(train_loader)

          # Time axis: one point per trajectory sample, spanning 0-450 s.
          ts_plot = np.linspace(0, 450, data_dim)
          seabornSettings()

          fig, ax = plt.subplots(figsize=(4, 4))
          for j in range(n_train):
              x_j, label_j = train_loader.dataset[j]
              ax.plot(ts_plot, x_j.cpu(), color=label_colors[int(label_j)], alpha=0.75)
          legend_handles = [
              mlines.Line2D([], [], color='C2', label='Escape'),
              mlines.Line2D([], [], color='C3', label='Capture'),
              mlines.Line2D([], [], color='C4', label='Impact'),
          ]
          ax.legend(handles=legend_handles)
          ax.hlines(0, 0, ts_plot[-1], colors='r', linestyles='dashed')
          ax.set_xlabel("Time [s]")
          ax.set_ylabel("Nondimensionalized Energy")
          ax.set_title("Training Data")
          fig.tight_layout()
          fig.savefig(os.path.join(dirname, 'train_data' + postfix + '.png'), dpi=dpi)
          plt.close(fig)  # the sweep creates many figures; free each one once saved

          # Validation and test trajectories, coloured by label.
          for title, file_stem, loader, n_split in (("Validation Data", 'val_data', val_loader, n_val),
                                                    ("Test Data", 'test_data', test_loader, n_test)):
              fig, ax = plt.subplots(figsize=(4.5, 4))
              for j in range(n_split):
                  x_j, label_j = loader.dataset[j]
                  ax.plot(ts_plot, x_j.cpu(), color=label_colors[int(label_j)])
              ax.set_xlabel("Time [s]")
              ax.set_ylabel("Nondimensionalized Energy")
              ax.set_title(title)
              fig.tight_layout()
              fig.savefig(os.path.join(dirname, file_stem + postfix + '.png'), dpi=dpi)
              plt.close(fig)
          # ---- Model: latent GMM parameters, encoder/decoder networks and optimiser ----
          params = {}
          # Unconstrained mixture-weight logits, optimised by Adam together with the networks.
          pi_variables = torch.zeros(K, requires_grad=True)
          params['pi_c'] = torch.ones(K) / K  # uniform initial mixture weights
          params['mu_c'] = torch.rand((K, Z)) * 2.0 - 1.0  # cluster means ~ U(-1, 1)
          params['logsigmasq_c'] = torch.zeros((K, Z))  # unit cluster variances

          text_labels = [f'Cluster {i}' for i in range(n_clusters)]
          label_colors = [f'C{i+1}' for i in range(n_clusters)]

          # One encoder/decoder pair per modality; the decoder mirrors the encoder's layer widths.
          encoder_list = []
          decoder_list = []
          trainable_parameters = [pi_variables]
          for _ in range(D):
              encoder = Encoder(data_dim=data_dim, latent_dim=latent_dim, hidden_dims=hidden_dims)
              decoder = Decoder(data_dim=data_dim, latent_dim=latent_dim, hidden_dims=hidden_dims[::-1], decoder_var=decoder_var)
              encoder_list.append(encoder)
              decoder_list.append(decoder)
              trainable_parameters += list(encoder.parameters()) + list(decoder.parameters())

          optimizer = optim.Adam(trainable_parameters, lr=lr)

          # ---- Training ----
          # Skipped when loadFlag is set: training would otherwise write new checkpoints and
          # histories over the files of the run that is about to be loaded (same dirname/postfix).
          import time
          from datetime import timedelta

          if loadFlag:
              print("loadFlag is set: skipping training, using saved results in " + dirname)
          else:
              tic = time.perf_counter()

              # Per-epoch histories (created with the default tensor type chosen in the device cell).
              train_loss = torch.zeros(n_epochs)
              train_elbo_terms = torch.zeros((n_epochs, 4))  # 4 ELBO terms
              val_elbo_terms = torch.zeros((n_epochs, 4))  # 4 ELBO terms
              val_loss = torch.zeros(n_epochs)
              pi_history = torch.zeros((n_epochs, K))
              train_mse_history = torch.zeros(n_epochs)
              val_mse_history = torch.zeros(n_epochs)
              min_val_loss = torch.inf
              seabornSettings()

              for epoch in range(n_epochs):
                  ti = time.perf_counter()
                  # Diagnostics on epoch 0, every plot_interval epochs and the final epoch
                  # (range(n_epochs) stops at n_epochs - 1, so `epoch == n_epochs` never happens).
                  is_plot_epoch = epoch % plot_interval == 0 or epoch == n_epochs - 1

                  for encoder in encoder_list:
                      encoder.train()
                  for decoder in decoder_list:
                      decoder.train()

                  train_elbo = 0
                  train_mse = 0
                  train_elbo_term = np.zeros(4)
                  # Fresh accumulators for em_step(..., update_by_batch=True) at the start of each epoch.
                  params['hist_weights'] = torch.zeros((K, 1))
                  params['hist_mu_c'] = torch.zeros((K, latent_dim))
                  params['hist_logsigmasq_c'] = torch.zeros((K, latent_dim))

                  for batch_x, _ in train_loader:
                      x_list = [batch_x]  # single modality (D = 1)
                      optimizer.zero_grad()
                      # Mixture weights: softmax of the trainable logits.
                      pi_c = torch.exp(pi_variables) / torch.sum(torch.exp(pi_variables))
                      params['pi_c'] = pi_c

                      mu, logsigmasq = encoder_step(x_list, encoder_list, decoder_list)
                      sigma = torch.exp(0.5 * logsigmasq)
                      eps = Normal(0, 1).sample(mu.shape)
                      z = mu + eps * sigma  # reparameterised latent sample

                      # Cluster means/variances come from EM (no gradient), not from the optimiser.
                      with torch.no_grad():
                          gamma_c, mu_c, logsigmasq_c = em_step(z, mu, logsigmasq, params, em_reg, update_by_batch=True)
                      params['mu_c'] = mu_c
                      params['logsigmasq_c'] = logsigmasq_c

                      elbo, sse, elbo_terms = decoder_step(x_list, z, encoder_list, decoder_list, params, mu, logsigmasq, gamma_c)
                      train_elbo += elbo.item()
                      train_elbo_term += elbo_terms
                      train_mse += sse.item()
                      loss = - elbo / batch_x.shape[0]  # negative ELBO per sample in the batch
                      loss.backward()
                      optimizer.step()

                  for encoder in encoder_list:
                      encoder.eval()
                  for decoder in decoder_list:
                      decoder.eval()

                  if is_plot_epoch:
                      # Training set in latent space: sampled z and posterior means.
                      with torch.no_grad():
                          means, samples, labels = [], [], []
                          for batch_x, batch_label in train_loader:
                              mean, logsigmasq = encoder_step([batch_x], encoder_list, decoder_list)
                              sigma = torch.exp(0.5 * logsigmasq)
                              eps = Normal(0, 1).sample(mean.shape)
                              z = mean + eps * sigma
                              means.append(mean)
                              samples.append(z)
                              labels.append(batch_label)
                      means = torch.vstack(means).cpu()
                      samples = torch.vstack(samples).cpu()
                      labels = torch.hstack(labels).cpu()

                      savepath = os.path.join(dirname, "latent_samples_epoch_" + str(epoch) + postfix)
                      plot_latent_space_with_clusters(samples, labels, K, mu_c.cpu(), logsigmasq_c.cpu(), savepath, text_labels, label_colors, label_colors, epoch, dpi=dpi)

                      savepath = os.path.join(dirname, "latent_means_epoch_" + str(epoch) + postfix)
                      plot_latent_space_with_clusters(means, labels, K, mu_c.cpu(), logsigmasq_c.cpu(), savepath, text_labels, label_colors, label_colors, epoch, dpi=dpi)

                      # Draws from the generative model: c ~ pi, z ~ N(mu_c, sigma_c^2), plot decoder output [0].
                      # Separate names keep the EM results (mu_c, logsigmasq_c) from being overwritten.
                      n_gen = n_train
                      cluster_probs = params['pi_c'].cpu().detach().numpy()
                      fig, ax = plt.subplots(figsize=(4.5, 4))
                      with torch.no_grad():
                          for _ in range(n_gen):
                              c = np.random.choice(K, p=cluster_probs)
                              mu_gen = params['mu_c'][c]
                              sigma_gen = torch.exp(0.5 * params['logsigmasq_c'][c])
                              z_gen = Normal(0, 1).sample(mu_gen.shape) * sigma_gen + mu_gen
                              mu_x = decoder.forward(z_gen)[0]
                              ax.plot(mu_x.cpu().numpy())
                      fig.savefig(os.path.join(dirname, "generate_samples_" + str(epoch) + postfix + '.png'), dpi=dpi)
                      plt.close(fig)

                  # Validation ELBO; only the responsibilities from em_step are used here.
                  val_elbo = 0
                  val_mse = 0
                  val_elbo_term = np.zeros(4)
                  with torch.no_grad():
                      for batch_x, _ in val_loader:
                          x_list = [batch_x]
                          mu, logsigmasq = encoder_step(x_list, encoder_list, decoder_list)
                          sigma = torch.exp(0.5 * logsigmasq)
                          eps = Normal(0, 1).sample(mu.shape)
                          z = mu + eps * sigma
                          gamma_c, _, _ = em_step(z, mu, logsigmasq, params, em_reg)
                          elbo, sse, elbo_items = decoder_step(x_list, z, encoder_list, decoder_list, params, mu, logsigmasq, gamma_c)
                          val_elbo += elbo.item()
                          val_mse += sse.item()
                          val_elbo_term += elbo_items

                  # Per-sample averages.
                  n_train_samples = len(train_loader.dataset)
                  n_val_samples = len(val_loader.dataset)
                  train_elbo /= n_train_samples
                  train_elbo_term = torch.tensor(train_elbo_term) / n_train_samples
                  val_elbo /= n_val_samples
                  val_elbo_term = torch.tensor(val_elbo_term) / n_val_samples
                  train_mse /= n_train_samples
                  val_mse /= n_val_samples

                  tf = time.perf_counter()
                  print('====> Epoch: {} Train ELBO: {:.4f} Val ELBO: {:.4f}, Epoch Time (s): {:.2f}, Total Time (hrs): {:.4f}'.format(epoch, train_elbo, val_elbo, tf - ti, (tf - tic) / 60 / 60))

                  train_loss[epoch] = - train_elbo
                  val_loss[epoch] = - val_elbo
                  train_elbo_terms[epoch, :] = - train_elbo_term
                  val_elbo_terms[epoch, :] = - val_elbo_term
                  # detach(): copying the autograd-tracked pi_c in would make pi_history part of the
                  # autograd graph and chain every epoch's graph into it.
                  pi_history[epoch] = params['pi_c'].detach()
                  train_mse_history[epoch] = train_mse
                  val_mse_history[epoch] = val_mse

                  # Keep the checkpoint with the lowest validation loss.
                  if - val_elbo < min_val_loss:
                      min_val_loss = - val_elbo
                      torch.save(params['pi_c'].detach(), os.path.join(dirname, 'gmm_params_pi' + postfix + '.pt'))
                      torch.save(params['mu_c'], os.path.join(dirname, 'gmm_params_mu' + postfix + '.pt'))
                      torch.save(params['logsigmasq_c'], os.path.join(dirname, 'gmm_params_logsigmasq' + postfix + '.pt'))
                      torch.save(encoder.state_dict(), os.path.join(dirname, 'encoder' + postfix + '.pt'))
                      torch.save(decoder.state_dict(), os.path.join(dirname, 'decoder' + postfix + '.pt'))

                  if is_plot_epoch:
                      n_done = epoch + 1  # include the epoch recorded just above

                      # Training and validation loss vs. epoch.
                      fig, ax = plt.subplots(figsize=(4.5, 4))
                      ax.plot(train_loss.cpu()[:n_done], label='train')
                      ax.plot(val_loss.cpu()[:n_done], label='val')
                      ax.set_yscale('symlog')
                      ax.set_xlabel("number of epochs")
                      ax.set_ylabel("loss")
                      ax.set_title("Negative Loss")
                      ax.legend()
                      fig.tight_layout()
                      fig.savefig(os.path.join(dirname, 'elbo_' + str(epoch) + postfix + '.png'), dpi=dpi)
                      plt.close(fig)

                      # Each ELBO term, training (solid) vs validation (dashed).
                      term_names = ["Reconstruction", "GMM Reg", "Prob Reg", "Encoder Var"]
                      fig, ax = plt.subplots(figsize=(4.5, 4))
                      for ii, term_name in enumerate(term_names):
                          ax.plot(train_elbo_terms[:n_done, ii].cpu(), label=f"{term_name}: Train")
                          ax.plot(val_elbo_terms[:n_done, ii].cpu(), label=f"{term_name}: Val", linestyle='--')
                      ax.set_xlabel("number of epochs")
                      ax.set_yscale('symlog')
                      ax.set_ylabel("loss")
                      ax.set_title("Negative Loss Terms")
                      ax.legend()
                      fig.tight_layout()
                      fig.savefig(os.path.join(dirname, 'elbo_terms_' + str(epoch) + postfix + '.png'), dpi=dpi)
                      plt.close(fig)

                      # Reconstruction MSE vs. epoch.
                      fig, ax = plt.subplots(figsize=(4.5, 4))
                      ax.semilogy(train_mse_history.cpu().numpy()[:n_done], label='train')
                      ax.semilogy(val_mse_history.cpu().numpy()[:n_done], label='val')
                      ax.set_xlabel("number of epochs")
                      ax.legend()
                      fig.tight_layout()
                      fig.savefig(os.path.join(dirname, 'reconst_mse_' + str(epoch) + postfix + '.png'), dpi=dpi)
                      plt.close(fig)

                      # Mixture-weight history.
                      fig, ax = plt.subplots(figsize=(4.5, 4))
                      for i in range(K):
                          ax.plot(pi_history[:, i].cpu().numpy()[:n_done], label=r'$\pi$' + str(i + 1))
                      ax.set_xlabel("number of epochs")
                      ax.legend()
                      fig.tight_layout()
                      fig.savefig(os.path.join(dirname, 'pi_' + str(epoch) + postfix + '.png'), dpi=dpi)
                      plt.close(fig)

              print("Training took ", timedelta(seconds=time.perf_counter() - tic))

              # Persist the per-epoch histories; they are reloaded for the final figures.
              torch.save(pi_history, os.path.join(dirname, 'pi_history' + postfix + '.pt'))
              torch.save(train_loss, os.path.join(dirname, 'train_loss' + postfix + '.pt'))
              torch.save(val_loss, os.path.join(dirname, 'val_loss' + postfix + '.pt'))
              torch.save(train_elbo_terms, os.path.join(dirname, 'train_elbo_terms' + postfix + '.pt'))
              torch.save(val_elbo_terms, os.path.join(dirname, 'val_elbo_terms' + postfix + '.pt'))
              torch.save(train_mse_history, os.path.join(dirname, 'train_mse_history' + postfix + '.pt'))
              torch.save(val_mse_history, os.path.join(dirname, 'val_mse_history' + postfix + '.pt'))

          # ---- Rebuild the networks and reload the saved training histories ----
          if loadFlag:
              # Use the last epoch index so the figure filenames below match a full run.
              path = dirname
              epoch = n_epochs - 1
          params = {}
          encoder = Encoder(data_dim=data_dim, latent_dim=latent_dim, hidden_dims=hidden_dims).to(device)
          decoder = Decoder(data_dim=data_dim, latent_dim=latent_dim, hidden_dims=hidden_dims[::-1], decoder_var=decoder_var).to(device)

          # Use if you need to load in the data.
          # Files are saved as '<name>' + postfix + '.pt' and postfix already starts with '_',
          # so no extra separator is inserted here.
          if loadFlag:
              cpu = torch.device('cpu')
              encoder.load_state_dict(torch.load(os.path.join(path, 'encoder' + postfix + '.pt'), map_location=cpu))
              decoder.load_state_dict(torch.load(os.path.join(path, 'decoder' + postfix + '.pt'), map_location=cpu))
              logsigmasq = torch.load(os.path.join(path, 'gmm_params_logsigmasq' + postfix + '.pt'), map_location=cpu)
              mu = torch.load(os.path.join(path, 'gmm_params_mu' + postfix + '.pt'), map_location=cpu)
              pi = torch.load(os.path.join(path, 'gmm_params_pi' + postfix + '.pt'), map_location=cpu)

          encoder_list = [encoder]
          decoder_list = [decoder]

          device = next(encoder.parameters()).device
          # Per-epoch training history metrics saved after training.
          pi_history = torch.load(os.path.join(dirname, 'pi_history' + postfix + '.pt'), map_location=device)
          train_loss = torch.load(os.path.join(dirname, 'train_loss' + postfix + '.pt'), map_location=device)
          val_loss = torch.load(os.path.join(dirname, 'val_loss' + postfix + '.pt'), map_location=device)
          train_elbo_terms = torch.load(os.path.join(dirname, 'train_elbo_terms' + postfix + '.pt'), map_location=device)
          val_elbo_terms = torch.load(os.path.join(dirname, 'val_elbo_terms' + postfix + '.pt'), map_location=device)
          train_mse_history = torch.load(os.path.join(dirname, 'train_mse_history' + postfix + '.pt'), map_location=device)
          val_mse_history = torch.load(os.path.join(dirname, 'val_mse_history' + postfix + '.pt'), map_location=device)

          text_labels = [f'Cluster {i}' for i in range(n_clusters)]
          label_colors = [f'C{i+1}' for i in range(n_clusters)]
          data_colors = label_colors

          # ---- Full training-history figures (all epochs) ----
          # Each figure is closed after saving: the sweep would otherwise keep every figure of
          # every configuration open.

          # Training and validation loss (per-sample negative ELBO) vs. epoch.
          fig, ax = plt.subplots(figsize=(4, 4))
          ax.plot(train_loss.detach().cpu(), label='Training')
          ax.plot(val_loss.detach().cpu(), label='Validation')
          ax.set_xlabel("Number of Epochs")
          ax.set_ylabel("Negative Loss")
          ax.set_yscale('symlog')
          ax.legend()
          fig.tight_layout()
          fig.savefig(os.path.join(dirname, 'elbo_' + str(epoch) + postfix + '.png'), dpi=dpi)
          plt.close(fig)

          # Each ELBO term, training (solid) vs validation (dashed).
          term_names = ["Reconstruction", "GMM Reg", "Prob Reg", "Encoder Var"]
          fig, ax = plt.subplots(figsize=(4.5, 4))
          for ii, term_name in enumerate(term_names):
              ax.plot(train_elbo_terms[:, ii].detach().cpu(), label=f"{term_name}: Train")
              ax.plot(val_elbo_terms[:, ii].detach().cpu(), label=f"{term_name}: Val", linestyle='--')
          ax.set_xlabel("number of epochs")
          ax.set_ylabel("loss")
          ax.set_title("Negative Loss Terms")
          ax.set_yscale('symlog')
          ax.legend()
          fig.tight_layout()
          fig.savefig(os.path.join(dirname, 'elbo_terms_' + str(epoch) + postfix + '.png'), dpi=dpi)
          plt.close(fig)

          # Sum of the stored ELBO terms.
          fig, ax = plt.subplots(figsize=(4.5, 4))
          ax.plot(train_elbo_terms.sum(axis=1).detach().cpu(), label="Train")
          ax.plot(val_elbo_terms.sum(axis=1).detach().cpu(), label="Val", linestyle='--')
          ax.set_xlabel("number of epochs")
          ax.set_ylabel("loss")
          ax.set_title("Sum of ELBO Terms")
          ax.set_yscale('symlog')
          ax.legend()
          fig.tight_layout()
          fig.savefig(os.path.join(dirname, 'elbo_terms_sum_' + str(epoch) + postfix + '.png'), dpi=dpi)
          plt.close(fig)

          # Training and validation reconstruction MSE vs. epoch.
          fig, ax = plt.subplots(figsize=(4.5, 4))
          ax.semilogy(train_mse_history.detach().cpu().numpy(), label='train')
          ax.semilogy(val_mse_history.detach().cpu().numpy(), label='val')
          ax.set_xlabel("number of epochs")
          ax.legend()
          fig.tight_layout()
          fig.savefig(os.path.join(dirname, 'reconst_mse_' + str(epoch) + postfix + '.png'), dpi=dpi)
          plt.close(fig)

          # Mixture-weight history.
          # NOTE(review): the 0.05 / 0.95 reference lines are hard-coded "true" escape/capture
          # probabilities; confirm they match the dataset selected by `data`.
          fig, ax = plt.subplots(figsize=(4, 4))
          for i in range(K):
              ax.plot(pi_history[:, i].detach().cpu().numpy(), label=text_labels[i] + r' $\pi$', color=label_colors[i])
          ax.set_xlabel("Number of Epochs")
          ax.set_ylabel("Predicted Cluster Probability")
          ax.axhline(y=0.05, color='C1', linestyle='--', label='True Escape Probability')
          ax.axhline(y=0.95, color='C3', linestyle='--', label='True Capture Probability')
          ax.legend()
          fig.tight_layout()
          fig.savefig(os.path.join(dirname, 'pi_' + str(epoch) + postfix + '.png'), dpi=dpi)
          plt.close(fig)

          # ---- Best checkpoint (lowest validation loss) ----
          gmm_files = {'pi_c': 'gmm_params_pi', 'mu_c': 'gmm_params_mu', 'logsigmasq_c': 'gmm_params_logsigmasq'}
          for param_key, file_stem in gmm_files.items():
              params[param_key] = torch.load(os.path.join(dirname, file_stem + postfix + '.pt'))
          encoder.load_state_dict(torch.load(os.path.join(dirname, 'encoder' + postfix + '.pt')))
          decoder.load_state_dict(torch.load(os.path.join(dirname, 'decoder' + postfix + '.pt')))
          # encoder_list / decoder_list hold these same module objects.
          encoder.eval()
          decoder.eval()

          # One more pass over the training set with the best networks: re-run the batch-wise EM
          # update of the cluster means/variances and collect latents for the "BEST" figures.
          means, samples, labels = [], [], []
          with torch.no_grad():
              params['hist_weights'] = torch.zeros((K, 1))
              params['hist_mu_c'] = torch.zeros((K, latent_dim))
              params['hist_logsigmasq_c'] = torch.zeros((K, latent_dim))
              for batch_x, batch_label in train_loader:
                  mu, logsigmasq = encoder_step([batch_x], encoder_list, decoder_list)
                  z = mu + Normal(0, 1).sample(mu.shape) * torch.exp(0.5 * logsigmasq)
                  gamma_c, mu_c, logsigmasq_c = em_step(z, mu, logsigmasq, params, em_reg, update_by_batch=True)
                  params['mu_c'], params['logsigmasq_c'] = mu_c, logsigmasq_c
                  means.append(mu)
                  samples.append(z)
                  labels.append(batch_label)

          means = torch.vstack(means).cpu()
          samples = torch.vstack(samples).cpu()
          labels = torch.hstack(labels).cpu()

          for figure_stem, latent_points in (("BEST_latent_samples", samples), ("BEST_latent_means", means)):
              savepath = os.path.join(dirname, figure_stem + postfix)
              plot_latent_space_with_clusters(latent_points, labels, K, mu_c.cpu(), logsigmasq_c.cpu(), savepath, text_labels, label_colors, data_colors, dpi=dpi)

          # ---- Held-out test set in latent space (posterior means) against the best GMM ----
          test_means, test_labels = [], []
          with torch.no_grad():
              for batch_x, batch_label in test_loader:
                  posterior_mean, _ = encoder_step([batch_x], encoder_list, decoder_list)
                  test_means.append(posterior_mean)
                  test_labels.append(batch_label)

          test_means = torch.vstack(test_means).cpu()
          test_labels = torch.hstack(test_labels).cpu()

          savepath = os.path.join(dirname, "BEST_test_latent_samples" + postfix)
          plot_latent_space_with_clusters(test_means, test_labels, K, mu_c.cpu(), logsigmasq_c.cpu(), savepath,
                                          text_labels, label_colors, data_colors, dpi=dpi)

          # ---- Decoded cluster means (best model) ---- todo: expand this for multi-modal data
          fig, ax = plt.subplots(figsize=(4.5, 4))
          with torch.no_grad():
              for i in range(K):
                  x_mean = decoder.forward(params['mu_c'][i])[0]
                  # Raw string: "$\mu$" in a normal literal relies on an invalid escape sequence.
                  ax.plot(x_mean.cpu().numpy(), label=r"decoded $\mu$" + str(i + 1))
          ax.legend()
          ax.set_title("Decoded Means")
          fig.savefig(os.path.join(dirname, "BEST_decoded_means" + postfix + '.png'), dpi=dpi)
          plt.close(fig)

          # ---- Samples from the full generative model ----
          # c ~ pi, z ~ N(mu_c, sigma_c^2), x ~ N(mu_x, sigma_x^2), where mu_x and log sigma_x^2 are
          # the first two decoder outputs. Separate names keep the EM results mu_c/logsigmasq_c intact.
          n_gen = n_train
          cluster_probs = params['pi_c'].cpu().detach().numpy()
          fig, ax = plt.subplots(figsize=(4.5, 4))
          with torch.no_grad():
              for _ in range(n_gen):
                  c = np.random.choice(K, p=cluster_probs)
                  mu_gen = params['mu_c'][c]
                  sigma_gen = torch.exp(0.5 * params['logsigmasq_c'][c])
                  z_gen = Normal(0, 1).sample(mu_gen.shape) * sigma_gen + mu_gen
                  dec_out = decoder.forward(z_gen.to(device))  # one call yields mean and log-variance
                  mu_x = dec_out[0]
                  sigma_x = torch.exp(0.5 * dec_out[1])
                  sample_x = Normal(0, 1).sample(mu_x.shape) * sigma_x + mu_x
                  ax.plot(sample_x.cpu().numpy())
          ax.set_title("Generated Samples")
          fig.savefig(os.path.join(dirname, "BEST_generate_samples" + postfix + '.png'), dpi=dpi)
          plt.close(fig)

          # ---- Run summary ----
          # `labels` holds the training-set labels collected in the final EM pass over train_loader.
          print("Training data size", n_train)
          # NOTE(review): assumes label 0 marks the "downward" curves; confirm against the dataset.
          print("Fraction of downward curves:", (torch.sum(labels == 0) / labels.numel()).item())
          # This is the minimum over all clusters, not specifically cluster 1.
          print("Smallest cluster probability:", cluster_probs.min().item())
          # Free every figure created for this configuration before the next sweep iteration.
          plt.close('all')

Filepath directory: drive/MyDrive/JP_gmvae_data/gmvae_near_escape_OLD_20250529_201140_L4_C2
Filepath postfix: _42_1024_128_128_32_16_4_2_0.001000_128_0.001000_0.000010_10000.000000_
[4995, 4418, 2999, 4714, 2263, 4629, 2776, 1179, 932, 792, 1851, 1185, 1723, 4078, 4755, 4052, 2166, 3901, 4155, 759, 2415, 3668, 4837, 3788, 1546, 1132, 3258, 4493, 2441, 3183, 1213, 3287, 763, 1216, 348, 4107, 2914, 1881, 1355, 3743, 1513, 3789, 1937, 1660, 2634, 804, 1441, 1094, 1840, 2723, 3741, 3625, 1453, 2093, 1999, 4812, 4075, 988, 1621, 4197, 1412, 608, 1771, 4577, 4960, 4827, 4525, 2637, 2833, 2505, 4047, 2204, 2754, 636, 503, 3882, 3998, 2529, 4284, 4588, 2670, 4763, 3818, 3782, 874, 200, 4385, 3039, 1721, 2996, 3532, 3113, 79, 2279, 349, 325, 3091, 280, 694, 1727, 3085, 1642, 927, 4622, 2857, 3204, 2966, 1524, 3282, 1686, 871, 610, 1995, 2849, 2969, 573, 3339, 219, 1273, 799, 4921, 3651, 1914, 1512, 3581, 574, 3157, 2938, 2193, 3585, 2612, 501, 1573, 3988, 2863, 881, 1586, 1235, 1626, 3179, 3814

  0%|          | 0/1280 [00:00<?, ?it/s]

[1;30;43mStreaming output truncated to the last 5000 lines.[0m
====> Epoch: 5010 Train ELBO: 240.0665 Val ELBO: 137.1671, Epoch Time (s): 0.09, Total Time (hrs): 0.1487
====> Epoch: 5011 Train ELBO: 239.2188 Val ELBO: 121.3567, Epoch Time (s): 0.08, Total Time (hrs): 0.1487
====> Epoch: 5012 Train ELBO: 234.2228 Val ELBO: 135.7840, Epoch Time (s): 0.08, Total Time (hrs): 0.1487
====> Epoch: 5013 Train ELBO: 236.3299 Val ELBO: 150.5005, Epoch Time (s): 0.09, Total Time (hrs): 0.1487
====> Epoch: 5014 Train ELBO: 233.8980 Val ELBO: 118.3742, Epoch Time (s): 0.08, Total Time (hrs): 0.1488
====> Epoch: 5015 Train ELBO: 237.6293 Val ELBO: 121.1039, Epoch Time (s): 0.08, Total Time (hrs): 0.1488
====> Epoch: 5016 Train ELBO: 237.5626 Val ELBO: 127.4420, Epoch Time (s): 0.09, Total Time (hrs): 0.1488
====> Epoch: 5017 Train ELBO: 233.2710 Val ELBO: 138.0856, Epoch Time (s): 0.08, Total Time (hrs): 0.1488
====> Epoch: 5018 Train ELBO: 232.5949 Val ELBO: 110.9152, Epoch Time (s): 0.09, Total 

  0%|          | 0/1280 [00:00<?, ?it/s]

[1;30;43mStreaming output truncated to the last 5000 lines.[0m
====> Epoch: 5010 Train ELBO: 239.8839 Val ELBO: 118.0671, Epoch Time (s): 0.09, Total Time (hrs): 0.1488
====> Epoch: 5011 Train ELBO: 238.6671 Val ELBO: 111.2584, Epoch Time (s): 0.08, Total Time (hrs): 0.1489
====> Epoch: 5012 Train ELBO: 236.2518 Val ELBO: 110.7068, Epoch Time (s): 0.08, Total Time (hrs): 0.1489
====> Epoch: 5013 Train ELBO: 238.2634 Val ELBO: 119.0584, Epoch Time (s): 0.09, Total Time (hrs): 0.1489
====> Epoch: 5014 Train ELBO: 236.9395 Val ELBO: 116.9350, Epoch Time (s): 0.08, Total Time (hrs): 0.1489
====> Epoch: 5015 Train ELBO: 240.5651 Val ELBO: 116.2200, Epoch Time (s): 0.08, Total Time (hrs): 0.1490
====> Epoch: 5016 Train ELBO: 240.1441 Val ELBO: 112.6625, Epoch Time (s): 0.08, Total Time (hrs): 0.1490
====> Epoch: 5017 Train ELBO: 234.3427 Val ELBO: 111.3283, Epoch Time (s): 0.08, Total Time (hrs): 0.1490
====> Epoch: 5018 Train ELBO: 232.3660 Val ELBO: 102.0496, Epoch Time (s): 0.08, Total 

  0%|          | 0/1280 [00:00<?, ?it/s]

  fig, ax = plt.subplots(figsize=(4.5, 4))


[1;30;43mStreaming output truncated to the last 5000 lines.[0m
====> Epoch: 5010 Train ELBO: 240.7783 Val ELBO: 105.7068, Epoch Time (s): 0.11, Total Time (hrs): 0.1449
====> Epoch: 5011 Train ELBO: 239.9130 Val ELBO: 105.8730, Epoch Time (s): 0.12, Total Time (hrs): 0.1449
====> Epoch: 5012 Train ELBO: 235.7156 Val ELBO: 94.4058, Epoch Time (s): 0.14, Total Time (hrs): 0.1450
====> Epoch: 5013 Train ELBO: 235.8875 Val ELBO: 107.5984, Epoch Time (s): 0.14, Total Time (hrs): 0.1450
====> Epoch: 5014 Train ELBO: 232.7479 Val ELBO: 102.3528, Epoch Time (s): 0.13, Total Time (hrs): 0.1451
====> Epoch: 5015 Train ELBO: 238.2370 Val ELBO: 103.6158, Epoch Time (s): 0.13, Total Time (hrs): 0.1451
====> Epoch: 5016 Train ELBO: 238.7208 Val ELBO: 99.5537, Epoch Time (s): 0.12, Total Time (hrs): 0.1451
====> Epoch: 5017 Train ELBO: 233.0760 Val ELBO: 96.7943, Epoch Time (s): 0.13, Total Time (hrs): 0.1452
====> Epoch: 5018 Train ELBO: 231.9277 Val ELBO: 88.7481, Epoch Time (s): 0.13, Total Time

  0%|          | 0/1280 [00:00<?, ?it/s]

[1;30;43mStreaming output truncated to the last 5000 lines.[0m
====> Epoch: 5010 Train ELBO: 239.5502 Val ELBO: 93.5747, Epoch Time (s): 0.09, Total Time (hrs): 0.1501
====> Epoch: 5011 Train ELBO: 239.8529 Val ELBO: 92.6915, Epoch Time (s): 0.08, Total Time (hrs): 0.1502
====> Epoch: 5012 Train ELBO: 235.1633 Val ELBO: 81.2261, Epoch Time (s): 0.09, Total Time (hrs): 0.1502
====> Epoch: 5013 Train ELBO: 235.5195 Val ELBO: 95.6080, Epoch Time (s): 0.09, Total Time (hrs): 0.1502
====> Epoch: 5014 Train ELBO: 231.9901 Val ELBO: 92.5864, Epoch Time (s): 0.08, Total Time (hrs): 0.1502
====> Epoch: 5015 Train ELBO: 237.9317 Val ELBO: 94.2988, Epoch Time (s): 0.08, Total Time (hrs): 0.1503
====> Epoch: 5016 Train ELBO: 239.0496 Val ELBO: 90.3967, Epoch Time (s): 0.09, Total Time (hrs): 0.1503
====> Epoch: 5017 Train ELBO: 234.5333 Val ELBO: 85.9794, Epoch Time (s): 0.08, Total Time (hrs): 0.1503
====> Epoch: 5018 Train ELBO: 233.7993 Val ELBO: 79.1602, Epoch Time (s): 0.08, Total Time (hrs

  0%|          | 0/1280 [00:00<?, ?it/s]

[1;30;43mStreaming output truncated to the last 5000 lines.[0m
====> Epoch: 5010 Train ELBO: 240.7318 Val ELBO: 97.5173, Epoch Time (s): 0.10, Total Time (hrs): 0.1517
====> Epoch: 5011 Train ELBO: 240.6113 Val ELBO: 93.4882, Epoch Time (s): 0.08, Total Time (hrs): 0.1518
====> Epoch: 5012 Train ELBO: 240.1535 Val ELBO: 94.9282, Epoch Time (s): 0.08, Total Time (hrs): 0.1518
====> Epoch: 5013 Train ELBO: 241.1363 Val ELBO: 98.2287, Epoch Time (s): 0.10, Total Time (hrs): 0.1518
====> Epoch: 5014 Train ELBO: 240.7176 Val ELBO: 98.2614, Epoch Time (s): 0.08, Total Time (hrs): 0.1518
====> Epoch: 5015 Train ELBO: 241.6535 Val ELBO: 96.5036, Epoch Time (s): 0.08, Total Time (hrs): 0.1518
====> Epoch: 5016 Train ELBO: 240.0726 Val ELBO: 96.5463, Epoch Time (s): 0.10, Total Time (hrs): 0.1519
====> Epoch: 5017 Train ELBO: 234.6789 Val ELBO: 90.9941, Epoch Time (s): 0.10, Total Time (hrs): 0.1519
====> Epoch: 5018 Train ELBO: 232.5122 Val ELBO: 79.3018, Epoch Time (s): 0.08, Total Time (hrs

  0%|          | 0/1280 [00:00<?, ?it/s]

[1;30;43mStreaming output truncated to the last 5000 lines.[0m
====> Epoch: 5010 Train ELBO: 241.2341 Val ELBO: 104.3967, Epoch Time (s): 0.08, Total Time (hrs): 0.1522
====> Epoch: 5011 Train ELBO: 238.7584 Val ELBO: 106.3112, Epoch Time (s): 0.08, Total Time (hrs): 0.1522
====> Epoch: 5012 Train ELBO: 237.2866 Val ELBO: 110.5166, Epoch Time (s): 0.10, Total Time (hrs): 0.1522
====> Epoch: 5013 Train ELBO: 241.6251 Val ELBO: 108.4410, Epoch Time (s): 0.10, Total Time (hrs): 0.1523
====> Epoch: 5014 Train ELBO: 241.6926 Val ELBO: 111.6202, Epoch Time (s): 0.08, Total Time (hrs): 0.1523
====> Epoch: 5015 Train ELBO: 238.1593 Val ELBO: 97.7087, Epoch Time (s): 0.09, Total Time (hrs): 0.1523
====> Epoch: 5016 Train ELBO: 242.3055 Val ELBO: 108.4876, Epoch Time (s): 0.08, Total Time (hrs): 0.1523
====> Epoch: 5017 Train ELBO: 241.8854 Val ELBO: 102.1859, Epoch Time (s): 0.08, Total Time (hrs): 0.1524
====> Epoch: 5018 Train ELBO: 245.6803 Val ELBO: 106.4513, Epoch Time (s): 0.09, Total T

  0%|          | 0/1280 [00:00<?, ?it/s]

[1;30;43mStreaming output truncated to the last 5000 lines.[0m
====> Epoch: 5001 Train ELBO: 249.1315 Val ELBO: -376.7560, Epoch Time (s): 0.08, Total Time (hrs): 0.1519
====> Epoch: 5002 Train ELBO: 248.6711 Val ELBO: -406.5355, Epoch Time (s): 0.08, Total Time (hrs): 0.1519
====> Epoch: 5003 Train ELBO: 248.6812 Val ELBO: -407.7495, Epoch Time (s): 0.08, Total Time (hrs): 0.1519
====> Epoch: 5004 Train ELBO: 249.5035 Val ELBO: -425.3087, Epoch Time (s): 0.10, Total Time (hrs): 0.1519
====> Epoch: 5005 Train ELBO: 249.0838 Val ELBO: -471.2059, Epoch Time (s): 0.08, Total Time (hrs): 0.1520
====> Epoch: 5006 Train ELBO: 250.1667 Val ELBO: -456.2469, Epoch Time (s): 0.08, Total Time (hrs): 0.1520
====> Epoch: 5007 Train ELBO: 249.0120 Val ELBO: -434.8166, Epoch Time (s): 0.09, Total Time (hrs): 0.1520
====> Epoch: 5008 Train ELBO: 247.8002 Val ELBO: -428.7968, Epoch Time (s): 0.08, Total Time (hrs): 0.1520
====> Epoch: 5009 Train ELBO: 247.0566 Val ELBO: -498.7784, Epoch Time (s): 0.0