In [1]:
%load_ext autoreload
%autoreload 2

# Import necessary modules 
import numpy as np
import pandas as pd
from itertools import product
import plotly.graph_objs as go
import plotly.express as px
from numpy.linalg import eigvalsh
from plotly.subplots import make_subplots
import torch
import itertools
import concurrent.futures
from tqdm import tqdm

from rbf_volatility_surface import RBFVolatilitySurface
from smoothness_prior import RBFQuadraticSmoothnessPrior
from dataset_sabr import generate_sabr_call_options
from surface_vae_trainer import SurfaceVAETrainer

In [2]:
# Observed quote grid: 10 strikes x 10 maturities
strike_price_list = np.array([0.75, 0.85, 0.9, 0.95, 1.0, 1.05, 1.1, 1.2, 1.3, 1.5])
maturity_time_list = np.array([0.02, 0.08, 0.17, 0.25, 0.5, 0.75, 1.0, 1.5, 2.0, 3.0])

# Every (maturity, strike) pair, with maturity as the slow-moving axis
product_grid = list(product(maturity_time_list, strike_price_list))

# Split the pairs into two flat, aligned coordinate arrays
maturity_times = np.array([pair[0] for pair in product_grid])
strike_prices = np.array([pair[1] for pair in product_grid])

# Variance formula for log-uniform distribution
def log_uniform_variance(a, b):
    """Return the variance of a log-uniform distribution with bounds ``a`` and ``b``.

    Uses the closed form
    ``(b**2 - a**2) / (2 * ln(b / a)) - ((b - a) / ln(b / a)) ** 2``.

    Parameters
    ----------
    a, b : float or array-like
        Bounds of the distribution; must be strictly positive. Array-like
        inputs are broadcast against each other.

    Returns
    -------
    numpy.float64 or numpy.ndarray
        The variance. Where ``a == b`` the distribution is a point mass and
        the result is 0 (the closed form would otherwise evaluate 0 / 0).

    Raises
    ------
    ValueError
        If any bound is zero or negative.
    """
    lower = np.asarray(a, dtype=float)
    upper = np.asarray(b, dtype=float)
    if np.any(lower <= 0) or np.any(upper <= 0):
        raise ValueError("log-uniform bounds must be strictly positive")

    log_term = np.log(upper / lower)
    # Equal bounds make log_term zero; substitute a dummy divisor there and
    # overwrite the result with the exact limit (0) afterwards.
    degenerate = log_term == 0
    safe_log = np.where(degenerate, 1.0, log_term)
    var = ((upper ** 2 - lower ** 2) / (2 * safe_log)) - ((upper - lower) / safe_log) ** 2
    # Indexing with () turns a 0-d result back into a scalar and leaves
    # n-d results unchanged.
    return np.where(degenerate, 0.0, var)[()]

# Standard deviations of log-uniform distributions spanning the observed
# maturity range and the observed strike range
maturity_variance = log_uniform_variance(np.min(maturity_time_list), np.max(maturity_time_list))
strike_variance = log_uniform_variance(np.min(strike_price_list), np.max(strike_price_list))
maturity_std = np.sqrt(maturity_variance)
strike_std = np.sqrt(strike_variance)

# SABR model parameters
alpha = 0.20  # Stochastic volatility parameter
beta = 0.50   # Elasticity parameter
rho = -0.75   # Correlation between asset price and volatility
nu = 1.0      # Volatility of volatility parameter

# Market inputs
risk_free_rate = np.log(1.02)  # log(1.02): continuously compounded equivalent of a 2% annual rate
underlying_price = 1.0         # Current price of the underlying asset

# Collect every generator argument in one place, then build the dataset
sabr_dataset_inputs = dict(
    alpha=alpha,
    beta=beta,
    rho=rho,
    nu=nu,
    maturity_times=maturity_times,
    strike_prices=strike_prices,
    risk_free_rate=risk_free_rate,
    underlying_price=underlying_price,
)
call_option_dataset = generate_sabr_call_options(**sabr_dataset_inputs)

# Dense, log-spaced evaluation axes (100 points each) that extend slightly
# beyond the observed maturity and strike ranges
hypothetical_maturity_time_list = np.logspace(np.log10(0.01), np.log10(3.1), 100)
hypothetical_strike_price_list = np.logspace(np.log10(0.7), np.log10(1.75), 100)

# Every (maturity, strike) pair on the dense axes, maturity varying slowest
hypothetical_product_grid = list(product(hypothetical_maturity_time_list, hypothetical_strike_price_list))
hypothetical_maturity_times = np.array([pair[0] for pair in hypothetical_product_grid])
hypothetical_strike_prices = np.array([pair[1] for pair in hypothetical_product_grid])

# 2-D views (rows = maturities, columns = strikes) for 3D surface plotting
hypothetical_grid_shape = (len(hypothetical_maturity_time_list), len(hypothetical_strike_price_list))
hypothetical_maturities_grid = hypothetical_maturity_times.reshape(hypothetical_grid_shape)
hypothetical_strikes_grid = hypothetical_strike_prices.reshape(hypothetical_grid_shape)

In [3]:
# Settings forwarded to RBFQuadraticSmoothnessPrior
n_roots = 350
smoothness_controller = 3.274549162877732e-05

# Build the smoothness prior on the observed (maturity, strike) points
smoothness_prior = RBFQuadraticSmoothnessPrior(
    maturity_times=maturity_times,
    strike_prices=strike_prices,
    maturity_std=maturity_std,
    strike_std=strike_std,
    n_roots=n_roots,
    smoothness_controller=smoothness_controller,
    random_state=0,
)

# Eigenvalues of the prior covariance, largest first; the final copy() yields
# a fresh array rather than a reversed view
prior_covariance_matrix = smoothness_prior.prior_covariance()
ascending_eigenvalues = np.sort(eigvalsh(prior_covariance_matrix))
prior_eigenvalues = np.flip(ascending_eigenvalues).copy()

# Constant volatility computed from the generated option dataset
constant_volatility = RBFVolatilitySurface.calculate_constant_volatility(
    call_option_dataset["Implied Volatility"],
    call_option_dataset["Time to Maturity"],
    call_option_dataset["Strike Price"],
    risk_free_rate,
    underlying_price,
)

# Draw 10000 surfaces from the prior
sampled_surface_coefficients = smoothness_prior.sample_smooth_surfaces(10000)

In [4]:
# Probe for a CUDA-capable GPU. The result is bound to a name (instead of
# being discarded after display) so it can be reused; the trailing bare
# expression keeps the cell's displayed output.
detected_device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')
detected_device

device(type='cuda')

In [13]:
# --- Fixed settings shared by every grid-search run ---
latent_dim = 70  # Latent dimension
data_dim = 100  # Data dimension of input (presumably the 10 x 10 maturity/strike grid points -- confirm)
latent_diagonal = prior_eigenvalues[:latent_dim]  # Leading (largest) prior eigenvalues for the latent prior
batch_size = 1000  # Batch size for training
# NOTE(review): this rebinds the notebook-level `beta`, which an earlier cell
# set to the SABR elasticity parameter (0.50). Any cell run after this one
# sees 1.0; consider a distinct name if both values are needed.
beta = 1.0  # Beta value for beta-VAE
fine_tune_learning_rate = 1e-4  # Fine-tune learning rate
pre_train_epochs = 350  # Number of pre-train epochs
fine_tune_epochs = 20  # Number of fine-tune epochs
# NOTE(review): hard-coded to CPU; the CUDA availability probe in the earlier
# cell is not consulted here.
device = "cpu"  # Use CPU as the device

# --- Hyperparameter grid ---
hidden_dim_grid = [128, 256, 512]  # Example grid for hidden_dim
n_layers_grid = [2, 4, 8]         # Example grid for n_layers
pre_train_learning_rate_grid = [1e-4, 1e-3, 1e-2]  # Example grid for learning rate

# Materialise the grid: itertools.product has no len(), so tqdm could not
# show a total or an ETA when handed the bare iterator.
grid = list(itertools.product(hidden_dim_grid, n_layers_grid, pre_train_learning_rate_grid))

# One final-loss row per configuration is collected here and turned into a
# DataFrame once, instead of re-concatenating the whole frame every iteration.
grid_search_rows = []
results_df = pd.DataFrame()

try:
    for hidden_dim, n_layers, pre_train_learning_rate in tqdm(grid):
        # Initialize the trainer with the specified configuration
        trainer = SurfaceVAETrainer(
            latent_dim=latent_dim,
            hidden_dim=hidden_dim,
            n_layers=n_layers,
            data_dim=data_dim,
            latent_diagonal=latent_diagonal,
            batch_size=batch_size,
            beta=beta,
            pre_train_learning_rate=pre_train_learning_rate,
            fine_tune_learning_rate=fine_tune_learning_rate,
            pre_train_epochs=pre_train_epochs,
            fine_tune_epochs=fine_tune_epochs,
            device=device,
        )

        # Train the model using pre_train
        trainer.pre_train_with_sampling(
            smoothness_prior=smoothness_prior,
            experiment_name=f"test_hd_{hidden_dim}_nl_{n_layers}_lr_{pre_train_learning_rate}"
        )

        # Keep only the last recorded entry of the pre-train loss history
        loss_df = pd.DataFrame(trainer.pre_train_loss_history)
        last_row = loss_df.iloc[-1].copy()

        # Tag the row with the configuration that produced it
        last_row['hidden_dim'] = hidden_dim
        last_row['n_layers'] = n_layers
        last_row['pre_train_learning_rate'] = pre_train_learning_rate

        grid_search_rows.append(last_row)
finally:
    # Build the table even if a run fails or the cell is interrupted, so the
    # configurations that did finish are not lost.
    results_df = pd.DataFrame(grid_search_rows).reset_index(drop=True)

0it [00:00, ?it/s]

Epoch 1/350, Batch 1, Losses: {'Reconstruction Loss': 35.783843994140625, 'KL Loss': 62292.15234375, 'Total Loss': 62327.9375}
Epoch 2/350, Batch 1, Losses: {'Reconstruction Loss': 35.19465255737305, 'KL Loss': 62269.28125, 'Total Loss': 62304.4765625}
Epoch 3/350, Batch 1, Losses: {'Reconstruction Loss': 35.06698989868164, 'KL Loss': 62241.4296875, 'Total Loss': 62276.49609375}
Epoch 4/350, Batch 1, Losses: {'Reconstruction Loss': 34.57255554199219, 'KL Loss': 62216.95703125, 'Total Loss': 62251.53125}
Epoch 5/350, Batch 1, Losses: {'Reconstruction Loss': 34.136234283447266, 'KL Loss': 62197.3203125, 'Total Loss': 62231.45703125}
Epoch 6/350, Batch 1, Losses: {'Reconstruction Loss': 33.781558990478516, 'KL Loss': 62183.03125, 'Total Loss': 62216.8125}
Epoch 7/350, Batch 1, Losses: {'Reconstruction Loss': 33.794708251953125, 'KL Loss': 62151.74609375, 'Total Loss': 62185.5390625}
Epoch 8/350, Batch 1, Losses: {'Reconstruction Loss': 33.21261215209961, 'KL Loss': 62132.84765625, 'Total 

1it [00:02,  2.43s/it]

Epoch 324/350, Batch 1, Losses: {'Reconstruction Loss': 1.046960473060608, 'KL Loss': 9004.162109375, 'Total Loss': 9005.208984375}
Epoch 325/350, Batch 1, Losses: {'Reconstruction Loss': 1.0275088548660278, 'KL Loss': 8938.263671875, 'Total Loss': 8939.291015625}
Epoch 326/350, Batch 1, Losses: {'Reconstruction Loss': 1.0193464756011963, 'KL Loss': 8837.9599609375, 'Total Loss': 8838.9794921875}
Epoch 327/350, Batch 1, Losses: {'Reconstruction Loss': 1.0096559524536133, 'KL Loss': 8787.01171875, 'Total Loss': 8788.021484375}
Epoch 328/350, Batch 1, Losses: {'Reconstruction Loss': 0.9915394186973572, 'KL Loss': 8677.15234375, 'Total Loss': 8678.1435546875}
Epoch 329/350, Batch 1, Losses: {'Reconstruction Loss': 0.9834733009338379, 'KL Loss': 8598.939453125, 'Total Loss': 8599.9228515625}
Epoch 330/350, Batch 1, Losses: {'Reconstruction Loss': 0.971636176109314, 'KL Loss': 8548.1494140625, 'Total Loss': 8549.12109375}
Epoch 331/350, Batch 1, Losses: {'Reconstruction Loss': 0.95208960771

2it [00:04,  2.36s/it]

Epoch 344/350, Batch 1, Losses: {'Reconstruction Loss': 0.0521734282374382, 'KL Loss': 9.247784614562988, 'Total Loss': 9.299958229064941}
Epoch 345/350, Batch 1, Losses: {'Reconstruction Loss': 0.05209615081548691, 'KL Loss': 9.141448020935059, 'Total Loss': 9.193544387817383}
Epoch 346/350, Batch 1, Losses: {'Reconstruction Loss': 0.05219423770904541, 'KL Loss': 9.13918399810791, 'Total Loss': 9.191378593444824}
Epoch 347/350, Batch 1, Losses: {'Reconstruction Loss': 0.051967803388834, 'KL Loss': 9.002315521240234, 'Total Loss': 9.054283142089844}
Epoch 348/350, Batch 1, Losses: {'Reconstruction Loss': 0.052250418812036514, 'KL Loss': 8.738923072814941, 'Total Loss': 8.791173934936523}
Epoch 349/350, Batch 1, Losses: {'Reconstruction Loss': 0.052117131650447845, 'KL Loss': 8.819272994995117, 'Total Loss': 8.871390342712402}
Epoch 350/350, Batch 1, Losses: {'Reconstruction Loss': 0.05197640508413315, 'KL Loss': 8.833328247070312, 'Total Loss': 8.88530445098877}
Epoch 1/350, Batch 1, L

3it [00:07,  2.37s/it]

Epoch 328/350, Batch 1, Losses: {'Reconstruction Loss': 0.053298186510801315, 'KL Loss': 0.16471032798290253, 'Total Loss': 0.21800851821899414}
Epoch 329/350, Batch 1, Losses: {'Reconstruction Loss': 0.05338768661022186, 'KL Loss': 0.16450852155685425, 'Total Loss': 0.2178962081670761}
Epoch 330/350, Batch 1, Losses: {'Reconstruction Loss': 0.05309620499610901, 'KL Loss': 0.1611703783273697, 'Total Loss': 0.2142665833234787}
Epoch 331/350, Batch 1, Losses: {'Reconstruction Loss': 0.053329579532146454, 'KL Loss': 0.16503416001796722, 'Total Loss': 0.21836373209953308}
Epoch 332/350, Batch 1, Losses: {'Reconstruction Loss': 0.0531780868768692, 'KL Loss': 0.16097499430179596, 'Total Loss': 0.21415308117866516}
Epoch 333/350, Batch 1, Losses: {'Reconstruction Loss': 0.053331419825553894, 'KL Loss': 0.1629045605659485, 'Total Loss': 0.21623598039150238}
Epoch 334/350, Batch 1, Losses: {'Reconstruction Loss': 0.05265409126877785, 'KL Loss': 0.16041439771652222, 'Total Loss': 0.2130684852600

4it [00:10,  2.58s/it]

Epoch 327/350, Batch 1, Losses: {'Reconstruction Loss': 0.08320364356040955, 'KL Loss': 250.1006317138672, 'Total Loss': 250.183837890625}
Epoch 328/350, Batch 1, Losses: {'Reconstruction Loss': 0.08235399425029755, 'KL Loss': 249.00003051757812, 'Total Loss': 249.08238220214844}
Epoch 329/350, Batch 1, Losses: {'Reconstruction Loss': 0.08147595077753067, 'KL Loss': 246.20396423339844, 'Total Loss': 246.2854461669922}
Epoch 330/350, Batch 1, Losses: {'Reconstruction Loss': 0.08235801756381989, 'KL Loss': 243.89840698242188, 'Total Loss': 243.9807586669922}
Epoch 331/350, Batch 1, Losses: {'Reconstruction Loss': 0.0812501311302185, 'KL Loss': 239.28221130371094, 'Total Loss': 239.36346435546875}
Epoch 332/350, Batch 1, Losses: {'Reconstruction Loss': 0.08085717260837555, 'KL Loss': 237.4496612548828, 'Total Loss': 237.530517578125}
Epoch 333/350, Batch 1, Losses: {'Reconstruction Loss': 0.08115128427743912, 'KL Loss': 235.83203125, 'Total Loss': 235.91317749023438}
Epoch 334/350, Batch 

5it [00:13,  2.77s/it]

Epoch 349/350, Batch 1, Losses: {'Reconstruction Loss': 0.0524778887629509, 'KL Loss': 14.842422485351562, 'Total Loss': 14.89490032196045}
Epoch 350/350, Batch 1, Losses: {'Reconstruction Loss': 0.052355486899614334, 'KL Loss': 17.694931030273438, 'Total Loss': 17.747285842895508}
Epoch 1/350, Batch 1, Losses: {'Reconstruction Loss': 19.229373931884766, 'KL Loss': 62277.90625, 'Total Loss': 62297.13671875}
Epoch 2/350, Batch 1, Losses: {'Reconstruction Loss': 5.767286777496338, 'KL Loss': 58377.5, 'Total Loss': 58383.265625}
Epoch 3/350, Batch 1, Losses: {'Reconstruction Loss': 2.2841358184814453, 'KL Loss': 38025.60546875, 'Total Loss': 38027.890625}
Epoch 4/350, Batch 1, Losses: {'Reconstruction Loss': 0.8368957042694092, 'KL Loss': 14285.0185546875, 'Total Loss': 14285.85546875}
Epoch 5/350, Batch 1, Losses: {'Reconstruction Loss': 1.7229710817337036, 'KL Loss': 20308.9609375, 'Total Loss': 20310.68359375}
Epoch 6/350, Batch 1, Losses: {'Reconstruction Loss': 0.9209786653518677, 'K

6it [00:16,  2.89s/it]

Epoch 346/350, Batch 1, Losses: {'Reconstruction Loss': 0.052129972726106644, 'KL Loss': 0.2479742467403412, 'Total Loss': 0.3001042306423187}
Epoch 347/350, Batch 1, Losses: {'Reconstruction Loss': 0.05214749276638031, 'KL Loss': 0.2458420991897583, 'Total Loss': 0.2979896068572998}
Epoch 348/350, Batch 1, Losses: {'Reconstruction Loss': 0.051857076585292816, 'KL Loss': 0.2438966929912567, 'Total Loss': 0.2957537770271301}
Epoch 349/350, Batch 1, Losses: {'Reconstruction Loss': 0.05187784135341644, 'KL Loss': 0.2430015206336975, 'Total Loss': 0.29487937688827515}
Epoch 350/350, Batch 1, Losses: {'Reconstruction Loss': 0.05244484171271324, 'KL Loss': 0.24418522417545319, 'Total Loss': 0.29663005471229553}
Epoch 1/350, Batch 1, Losses: {'Reconstruction Loss': 7.858724594116211, 'KL Loss': 62284.453125, 'Total Loss': 62292.3125}
Epoch 2/350, Batch 1, Losses: {'Reconstruction Loss': 7.416639804840088, 'KL Loss': 62215.93359375, 'Total Loss': 62223.3515625}
Epoch 3/350, Batch 1, Losses: {'

7it [00:20,  3.45s/it]

Epoch 344/350, Batch 1, Losses: {'Reconstruction Loss': 0.053577177226543427, 'KL Loss': 39.72312927246094, 'Total Loss': 39.77670669555664}
Epoch 345/350, Batch 1, Losses: {'Reconstruction Loss': 0.053138989955186844, 'KL Loss': 39.5050163269043, 'Total Loss': 39.55815505981445}
Epoch 346/350, Batch 1, Losses: {'Reconstruction Loss': 0.05369945615530014, 'KL Loss': 39.37202835083008, 'Total Loss': 39.42572784423828}
Epoch 347/350, Batch 1, Losses: {'Reconstruction Loss': 0.05366150289773941, 'KL Loss': 39.12004089355469, 'Total Loss': 39.173702239990234}
Epoch 348/350, Batch 1, Losses: {'Reconstruction Loss': 0.052942756563425064, 'KL Loss': 38.90363311767578, 'Total Loss': 38.95657730102539}
Epoch 349/350, Batch 1, Losses: {'Reconstruction Loss': 0.05354570597410202, 'KL Loss': 38.758296966552734, 'Total Loss': 38.81184387207031}
Epoch 350/350, Batch 1, Losses: {'Reconstruction Loss': 0.053573768585920334, 'KL Loss': 38.58019256591797, 'Total Loss': 38.633766174316406}
Epoch 1/350, B

8it [00:25,  3.68s/it]

Epoch 338/350, Batch 1, Losses: {'Reconstruction Loss': 0.0522395595908165, 'KL Loss': 121.64987182617188, 'Total Loss': 121.70211029052734}
Epoch 339/350, Batch 1, Losses: {'Reconstruction Loss': 0.052554503083229065, 'KL Loss': 123.96495056152344, 'Total Loss': 124.01750183105469}
Epoch 340/350, Batch 1, Losses: {'Reconstruction Loss': 0.05199681967496872, 'KL Loss': 78.51912689208984, 'Total Loss': 78.57112121582031}
Epoch 341/350, Batch 1, Losses: {'Reconstruction Loss': 0.05202696472406387, 'KL Loss': 18.07280158996582, 'Total Loss': 18.124828338623047}
Epoch 342/350, Batch 1, Losses: {'Reconstruction Loss': 0.051600199192762375, 'KL Loss': 4.115444660186768, 'Total Loss': 4.167044639587402}
Epoch 343/350, Batch 1, Losses: {'Reconstruction Loss': 0.051845405250787735, 'KL Loss': 37.34688949584961, 'Total Loss': 37.39873504638672}
Epoch 344/350, Batch 1, Losses: {'Reconstruction Loss': 0.0519779808819294, 'KL Loss': 56.92232894897461, 'Total Loss': 56.974308013916016}
Epoch 345/350

9it [00:29,  3.83s/it]

Epoch 340/350, Batch 1, Losses: {'Reconstruction Loss': 0.052185043692588806, 'KL Loss': 29.01661491394043, 'Total Loss': 29.06879997253418}
Epoch 341/350, Batch 1, Losses: {'Reconstruction Loss': 0.0522271953523159, 'KL Loss': 30.458602905273438, 'Total Loss': 30.51082992553711}
Epoch 342/350, Batch 1, Losses: {'Reconstruction Loss': 0.0525030754506588, 'KL Loss': 28.526046752929688, 'Total Loss': 28.578550338745117}
Epoch 343/350, Batch 1, Losses: {'Reconstruction Loss': 0.0523984357714653, 'KL Loss': 29.971649169921875, 'Total Loss': 30.0240478515625}
Epoch 344/350, Batch 1, Losses: {'Reconstruction Loss': 0.05222531780600548, 'KL Loss': 28.237951278686523, 'Total Loss': 28.290176391601562}
Epoch 345/350, Batch 1, Losses: {'Reconstruction Loss': 0.05254209041595459, 'KL Loss': 28.689754486083984, 'Total Loss': 28.74229621887207}
Epoch 346/350, Batch 1, Losses: {'Reconstruction Loss': 0.052479322999715805, 'KL Loss': 28.384801864624023, 'Total Loss': 28.437280654907227}
Epoch 347/350

10it [00:32,  3.59s/it]

Epoch 341/350, Batch 1, Losses: {'Reconstruction Loss': 0.1124337688088417, 'KL Loss': 1760.270263671875, 'Total Loss': 1760.3826904296875}
Epoch 342/350, Batch 1, Losses: {'Reconstruction Loss': 0.1116362139582634, 'KL Loss': 1742.2125244140625, 'Total Loss': 1742.32421875}
Epoch 343/350, Batch 1, Losses: {'Reconstruction Loss': 0.11077732592821121, 'KL Loss': 1727.5654296875, 'Total Loss': 1727.6761474609375}
Epoch 344/350, Batch 1, Losses: {'Reconstruction Loss': 0.10979445278644562, 'KL Loss': 1715.8614501953125, 'Total Loss': 1715.97119140625}
Epoch 345/350, Batch 1, Losses: {'Reconstruction Loss': 0.10958273708820343, 'KL Loss': 1693.831298828125, 'Total Loss': 1693.94091796875}
Epoch 346/350, Batch 1, Losses: {'Reconstruction Loss': 0.10956259071826935, 'KL Loss': 1684.5533447265625, 'Total Loss': 1684.6629638671875}
Epoch 347/350, Batch 1, Losses: {'Reconstruction Loss': 0.10973075032234192, 'KL Loss': 1672.1336669921875, 'Total Loss': 1672.243408203125}
Epoch 348/350, Batch 1,

11it [00:35,  3.44s/it]

Epoch 336/350, Batch 1, Losses: {'Reconstruction Loss': 0.05119527503848076, 'KL Loss': 3.905984878540039, 'Total Loss': 3.9571802616119385}
Epoch 337/350, Batch 1, Losses: {'Reconstruction Loss': 0.05083613842725754, 'KL Loss': 3.935466766357422, 'Total Loss': 3.9863028526306152}
Epoch 338/350, Batch 1, Losses: {'Reconstruction Loss': 0.05135028809309006, 'KL Loss': 3.9108643531799316, 'Total Loss': 3.962214708328247}
Epoch 339/350, Batch 1, Losses: {'Reconstruction Loss': 0.05086034536361694, 'KL Loss': 3.8000967502593994, 'Total Loss': 3.850957155227661}
Epoch 340/350, Batch 1, Losses: {'Reconstruction Loss': 0.05110646411776543, 'KL Loss': 3.6991283893585205, 'Total Loss': 3.750234842300415}
Epoch 341/350, Batch 1, Losses: {'Reconstruction Loss': 0.05120084807276726, 'KL Loss': 3.636204719543457, 'Total Loss': 3.687405586242676}
Epoch 342/350, Batch 1, Losses: {'Reconstruction Loss': 0.051193032413721085, 'KL Loss': 3.5918307304382324, 'Total Loss': 3.643023729324341}
Epoch 343/350

12it [00:38,  3.42s/it]

Epoch 328/350, Batch 1, Losses: {'Reconstruction Loss': 0.07826805859804153, 'KL Loss': 0.016704333946108818, 'Total Loss': 0.0949723944067955}
Epoch 329/350, Batch 1, Losses: {'Reconstruction Loss': 0.07775399833917618, 'KL Loss': 0.0162186436355114, 'Total Loss': 0.09397263824939728}
Epoch 330/350, Batch 1, Losses: {'Reconstruction Loss': 0.07744549214839935, 'KL Loss': 0.017100775614380836, 'Total Loss': 0.09454626590013504}
Epoch 331/350, Batch 1, Losses: {'Reconstruction Loss': 0.07747528702020645, 'KL Loss': 0.01640898548066616, 'Total Loss': 0.09388427436351776}
Epoch 332/350, Batch 1, Losses: {'Reconstruction Loss': 0.07660937309265137, 'KL Loss': 0.016190025955438614, 'Total Loss': 0.09279939532279968}
Epoch 333/350, Batch 1, Losses: {'Reconstruction Loss': 0.07714344561100006, 'KL Loss': 0.016445621848106384, 'Total Loss': 0.09358906745910645}
Epoch 334/350, Batch 1, Losses: {'Reconstruction Loss': 0.07684081792831421, 'KL Loss': 0.016388075426220894, 'Total Loss': 0.09322889

13it [00:43,  3.78s/it]

Epoch 335/350, Batch 1, Losses: {'Reconstruction Loss': 0.05383666232228279, 'KL Loss': 36.685543060302734, 'Total Loss': 36.7393798828125}
Epoch 336/350, Batch 1, Losses: {'Reconstruction Loss': 0.05404823645949364, 'KL Loss': 36.11515808105469, 'Total Loss': 36.16920471191406}
Epoch 337/350, Batch 1, Losses: {'Reconstruction Loss': 0.05382554978132248, 'KL Loss': 35.873233795166016, 'Total Loss': 35.927059173583984}
Epoch 338/350, Batch 1, Losses: {'Reconstruction Loss': 0.053708918392658234, 'KL Loss': 35.71458053588867, 'Total Loss': 35.768287658691406}
Epoch 339/350, Batch 1, Losses: {'Reconstruction Loss': 0.05360289663076401, 'KL Loss': 35.50375747680664, 'Total Loss': 35.5573616027832}
Epoch 340/350, Batch 1, Losses: {'Reconstruction Loss': 0.05368264019489288, 'KL Loss': 35.25383758544922, 'Total Loss': 35.30752182006836}
Epoch 341/350, Batch 1, Losses: {'Reconstruction Loss': 0.053907185792922974, 'KL Loss': 34.891170501708984, 'Total Loss': 34.94507598876953}
Epoch 342/350, 

14it [00:48,  4.10s/it]

Epoch 349/350, Batch 1, Losses: {'Reconstruction Loss': 0.05216651037335396, 'KL Loss': 1.962109088897705, 'Total Loss': 2.014275550842285}
Epoch 350/350, Batch 1, Losses: {'Reconstruction Loss': 0.05162657052278519, 'KL Loss': 2.0685877799987793, 'Total Loss': 2.1202144622802734}
Epoch 1/350, Batch 1, Losses: {'Reconstruction Loss': 20.02340316772461, 'KL Loss': 62261.953125, 'Total Loss': 62281.9765625}
Epoch 2/350, Batch 1, Losses: {'Reconstruction Loss': 12.904497146606445, 'KL Loss': 55344.4453125, 'Total Loss': 55357.3515625}
Epoch 3/350, Batch 1, Losses: {'Reconstruction Loss': 5.116997241973877, 'KL Loss': 28606.84765625, 'Total Loss': 28611.96484375}
Epoch 4/350, Batch 1, Losses: {'Reconstruction Loss': 37.248695373535156, 'KL Loss': 20792.990234375, 'Total Loss': 20830.23828125}
Epoch 5/350, Batch 1, Losses: {'Reconstruction Loss': 20.432266235351562, 'KL Loss': 14346.2060546875, 'Total Loss': 14366.638671875}
Epoch 6/350, Batch 1, Losses: {'Reconstruction Loss': 5.8047327995

15it [00:53,  4.34s/it]

Epoch 339/350, Batch 1, Losses: {'Reconstruction Loss': 0.06378515809774399, 'KL Loss': 0.1367240697145462, 'Total Loss': 0.2005092203617096}
Epoch 340/350, Batch 1, Losses: {'Reconstruction Loss': 0.063313789665699, 'KL Loss': 0.1327120065689087, 'Total Loss': 0.1960257887840271}
Epoch 341/350, Batch 1, Losses: {'Reconstruction Loss': 0.06311381608247757, 'KL Loss': 0.13335543870925903, 'Total Loss': 0.196469247341156}
Epoch 342/350, Batch 1, Losses: {'Reconstruction Loss': 0.06331650167703629, 'KL Loss': 0.13192296028137207, 'Total Loss': 0.19523945450782776}
Epoch 343/350, Batch 1, Losses: {'Reconstruction Loss': 0.06352193653583527, 'KL Loss': 0.13207946717739105, 'Total Loss': 0.19560140371322632}
Epoch 344/350, Batch 1, Losses: {'Reconstruction Loss': 0.06305763870477676, 'KL Loss': 0.1304318904876709, 'Total Loss': 0.19348952174186707}
Epoch 345/350, Batch 1, Losses: {'Reconstruction Loss': 0.06358928233385086, 'KL Loss': 0.12750312685966492, 'Total Loss': 0.19109240174293518}
E

16it [01:00,  5.36s/it]

Epoch 341/350, Batch 1, Losses: {'Reconstruction Loss': 0.05161155015230179, 'KL Loss': 13.23655891418457, 'Total Loss': 13.28817081451416}
Epoch 342/350, Batch 1, Losses: {'Reconstruction Loss': 0.05183123052120209, 'KL Loss': 13.199185371398926, 'Total Loss': 13.251016616821289}
Epoch 343/350, Batch 1, Losses: {'Reconstruction Loss': 0.05195095017552376, 'KL Loss': 13.113082885742188, 'Total Loss': 13.165034294128418}
Epoch 344/350, Batch 1, Losses: {'Reconstruction Loss': 0.05171671137213707, 'KL Loss': 13.08944320678711, 'Total Loss': 13.141160011291504}
Epoch 345/350, Batch 1, Losses: {'Reconstruction Loss': 0.05201737955212593, 'KL Loss': 13.035688400268555, 'Total Loss': 13.087705612182617}
Epoch 346/350, Batch 1, Losses: {'Reconstruction Loss': 0.0514620877802372, 'KL Loss': 12.924370765686035, 'Total Loss': 12.97583293914795}
Epoch 347/350, Batch 1, Losses: {'Reconstruction Loss': 0.05116961523890495, 'KL Loss': 12.797392845153809, 'Total Loss': 12.848562240600586}
Epoch 348/3

17it [01:08,  6.15s/it]

Epoch 341/350, Batch 1, Losses: {'Reconstruction Loss': 0.0515095479786396, 'KL Loss': 0.48039525747299194, 'Total Loss': 0.5319048166275024}
Epoch 342/350, Batch 1, Losses: {'Reconstruction Loss': 0.0514688715338707, 'KL Loss': 0.4750702679157257, 'Total Loss': 0.526539146900177}
Epoch 343/350, Batch 1, Losses: {'Reconstruction Loss': 0.05213850736618042, 'KL Loss': 0.4705789089202881, 'Total Loss': 0.5227174162864685}
Epoch 344/350, Batch 1, Losses: {'Reconstruction Loss': 0.05187954381108284, 'KL Loss': 0.48416072130203247, 'Total Loss': 0.5360402464866638}
Epoch 345/350, Batch 1, Losses: {'Reconstruction Loss': 0.05148284137248993, 'KL Loss': 0.5061107277870178, 'Total Loss': 0.557593584060669}
Epoch 346/350, Batch 1, Losses: {'Reconstruction Loss': 0.05152849853038788, 'KL Loss': 0.5091955661773682, 'Total Loss': 0.5607240796089172}
Epoch 347/350, Batch 1, Losses: {'Reconstruction Loss': 0.05222326144576073, 'KL Loss': 0.5120720267295837, 'Total Loss': 0.5642952919006348}
Epoch 34

18it [01:17,  6.86s/it]

Epoch 346/350, Batch 1, Losses: {'Reconstruction Loss': 0.05131435766816139, 'KL Loss': 12.939525604248047, 'Total Loss': 12.990839958190918}
Epoch 347/350, Batch 1, Losses: {'Reconstruction Loss': 0.05132294446229935, 'KL Loss': 12.524736404418945, 'Total Loss': 12.576059341430664}
Epoch 348/350, Batch 1, Losses: {'Reconstruction Loss': 0.051270145922899246, 'KL Loss': 12.693319320678711, 'Total Loss': 12.744589805603027}
Epoch 349/350, Batch 1, Losses: {'Reconstruction Loss': 0.05143524333834648, 'KL Loss': 12.482709884643555, 'Total Loss': 12.53414535522461}
Epoch 350/350, Batch 1, Losses: {'Reconstruction Loss': 0.05130962282419205, 'KL Loss': 12.442724227905273, 'Total Loss': 12.494033813476562}
Epoch 1/350, Batch 1, Losses: {'Reconstruction Loss': 22.86607551574707, 'KL Loss': 62258.99609375, 'Total Loss': 62281.86328125}
Epoch 2/350, Batch 1, Losses: {'Reconstruction Loss': 21.977439880371094, 'KL Loss': 62213.8359375, 'Total Loss': 62235.8125}
Epoch 3/350, Batch 1, Losses: {'Re

19it [01:23,  6.62s/it]

Epoch 347/350, Batch 1, Losses: {'Reconstruction Loss': 0.05186164751648903, 'KL Loss': 387.8343505859375, 'Total Loss': 387.8861999511719}
Epoch 348/350, Batch 1, Losses: {'Reconstruction Loss': 0.051984723657369614, 'KL Loss': 384.3451843261719, 'Total Loss': 384.39715576171875}
Epoch 349/350, Batch 1, Losses: {'Reconstruction Loss': 0.05157956853508949, 'KL Loss': 382.35870361328125, 'Total Loss': 382.4102783203125}
Epoch 350/350, Batch 1, Losses: {'Reconstruction Loss': 0.05181889235973358, 'KL Loss': 378.64031982421875, 'Total Loss': 378.692138671875}
Epoch 1/350, Batch 1, Losses: {'Reconstruction Loss': 22.369705200195312, 'KL Loss': 62263.17578125, 'Total Loss': 62285.546875}
Epoch 2/350, Batch 1, Losses: {'Reconstruction Loss': 14.704625129699707, 'KL Loss': 61812.66796875, 'Total Loss': 61827.37109375}
Epoch 3/350, Batch 1, Losses: {'Reconstruction Loss': 9.32398796081543, 'KL Loss': 61169.48828125, 'Total Loss': 61178.8125}
Epoch 4/350, Batch 1, Losses: {'Reconstruction Loss'

20it [01:29,  6.44s/it]

Epoch 338/350, Batch 1, Losses: {'Reconstruction Loss': 0.05212223529815674, 'KL Loss': 2.680366039276123, 'Total Loss': 2.7324881553649902}
Epoch 339/350, Batch 1, Losses: {'Reconstruction Loss': 0.05178684741258621, 'KL Loss': 2.416830539703369, 'Total Loss': 2.4686174392700195}
Epoch 340/350, Batch 1, Losses: {'Reconstruction Loss': 0.052310697734355927, 'KL Loss': 2.1481499671936035, 'Total Loss': 2.20046067237854}
Epoch 341/350, Batch 1, Losses: {'Reconstruction Loss': 0.051730651408433914, 'KL Loss': 2.1910321712493896, 'Total Loss': 2.242762804031372}
Epoch 342/350, Batch 1, Losses: {'Reconstruction Loss': 0.051929447799921036, 'KL Loss': 2.424981117248535, 'Total Loss': 2.4769105911254883}
Epoch 343/350, Batch 1, Losses: {'Reconstruction Loss': 0.05152466893196106, 'KL Loss': 2.4457762241363525, 'Total Loss': 2.497300863265991}
Epoch 344/350, Batch 1, Losses: {'Reconstruction Loss': 0.05210856348276138, 'KL Loss': 2.2811360359191895, 'Total Loss': 2.333244562149048}
Epoch 345/3

21it [01:34,  6.19s/it]

Epoch 342/350, Batch 1, Losses: {'Reconstruction Loss': 0.08711682260036469, 'KL Loss': 0.008332079276442528, 'Total Loss': 0.09544890373945236}
Epoch 343/350, Batch 1, Losses: {'Reconstruction Loss': 0.08730877935886383, 'KL Loss': 0.008264318108558655, 'Total Loss': 0.09557309746742249}
Epoch 344/350, Batch 1, Losses: {'Reconstruction Loss': 0.08702004700899124, 'KL Loss': 0.008371353149414062, 'Total Loss': 0.0953914001584053}
Epoch 345/350, Batch 1, Losses: {'Reconstruction Loss': 0.08665198087692261, 'KL Loss': 0.008279676549136639, 'Total Loss': 0.09493165463209152}
Epoch 346/350, Batch 1, Losses: {'Reconstruction Loss': 0.0868726298213005, 'KL Loss': 0.008216617628932, 'Total Loss': 0.09508924931287766}
Epoch 347/350, Batch 1, Losses: {'Reconstruction Loss': 0.08611618727445602, 'KL Loss': 0.008250354789197445, 'Total Loss': 0.09436654299497604}
Epoch 348/350, Batch 1, Losses: {'Reconstruction Loss': 0.08629117906093597, 'KL Loss': 0.008223766461014748, 'Total Loss': 0.094514943

22it [01:45,  7.48s/it]

Epoch 348/350, Batch 1, Losses: {'Reconstruction Loss': 0.05038776993751526, 'KL Loss': 9.259368896484375, 'Total Loss': 9.3097562789917}
Epoch 349/350, Batch 1, Losses: {'Reconstruction Loss': 0.05038949474692345, 'KL Loss': 9.250019073486328, 'Total Loss': 9.300408363342285}
Epoch 350/350, Batch 1, Losses: {'Reconstruction Loss': 0.050912026315927505, 'KL Loss': 9.276610374450684, 'Total Loss': 9.327522277832031}
Epoch 1/350, Batch 1, Losses: {'Reconstruction Loss': 14.769455909729004, 'KL Loss': 62262.015625, 'Total Loss': 62276.78515625}
Epoch 2/350, Batch 1, Losses: {'Reconstruction Loss': 5.787454128265381, 'KL Loss': 61493.9765625, 'Total Loss': 61499.765625}
Epoch 3/350, Batch 1, Losses: {'Reconstruction Loss': 2.7531304359436035, 'KL Loss': 59608.4609375, 'Total Loss': 59611.21484375}
Epoch 4/350, Batch 1, Losses: {'Reconstruction Loss': 2.0752525329589844, 'KL Loss': 56078.8046875, 'Total Loss': 56080.87890625}
Epoch 5/350, Batch 1, Losses: {'Reconstruction Loss': 1.921807169

23it [01:56,  8.56s/it]

Epoch 344/350, Batch 1, Losses: {'Reconstruction Loss': 0.052261337637901306, 'KL Loss': 0.5956524014472961, 'Total Loss': 0.6479137539863586}
Epoch 345/350, Batch 1, Losses: {'Reconstruction Loss': 0.052558258175849915, 'KL Loss': 0.5843269228935242, 'Total Loss': 0.6368851661682129}
Epoch 346/350, Batch 1, Losses: {'Reconstruction Loss': 0.05214058980345726, 'KL Loss': 0.5893219113349915, 'Total Loss': 0.641462504863739}
Epoch 347/350, Batch 1, Losses: {'Reconstruction Loss': 0.052049074321985245, 'KL Loss': 0.581742525100708, 'Total Loss': 0.6337916254997253}
Epoch 348/350, Batch 1, Losses: {'Reconstruction Loss': 0.05260421708226204, 'KL Loss': 0.5886766910552979, 'Total Loss': 0.6412808895111084}
Epoch 349/350, Batch 1, Losses: {'Reconstruction Loss': 0.05227208882570267, 'KL Loss': 0.5721477270126343, 'Total Loss': 0.6244198083877563}
Epoch 350/350, Batch 1, Losses: {'Reconstruction Loss': 0.0522773377597332, 'KL Loss': 0.5810750722885132, 'Total Loss': 0.6333523988723755}
Epoch 

24it [02:08,  9.68s/it]

Epoch 346/350, Batch 1, Losses: {'Reconstruction Loss': 0.05666546896100044, 'KL Loss': 13.19870662689209, 'Total Loss': 13.255372047424316}
Epoch 347/350, Batch 1, Losses: {'Reconstruction Loss': 0.056554313749074936, 'KL Loss': 12.663825988769531, 'Total Loss': 12.720379829406738}
Epoch 348/350, Batch 1, Losses: {'Reconstruction Loss': 0.055639106780290604, 'KL Loss': 12.10397720336914, 'Total Loss': 12.159616470336914}
Epoch 349/350, Batch 1, Losses: {'Reconstruction Loss': 0.057087332010269165, 'KL Loss': 12.186141014099121, 'Total Loss': 12.2432279586792}
Epoch 350/350, Batch 1, Losses: {'Reconstruction Loss': 0.05729350447654724, 'KL Loss': 12.541232109069824, 'Total Loss': 12.598526000976562}
Epoch 1/350, Batch 1, Losses: {'Reconstruction Loss': 7.860693454742432, 'KL Loss': 62256.84765625, 'Total Loss': 62264.70703125}
Epoch 2/350, Batch 1, Losses: {'Reconstruction Loss': 6.161736488342285, 'KL Loss': 62118.77734375, 'Total Loss': 62124.9375}
Epoch 3/350, Batch 1, Losses: {'Rec

25it [02:28, 12.69s/it]

Epoch 350/350, Batch 1, Losses: {'Reconstruction Loss': 0.05072923004627228, 'KL Loss': 2.842956066131592, 'Total Loss': 2.8936853408813477}
Epoch 1/350, Batch 1, Losses: {'Reconstruction Loss': 7.885909557342529, 'KL Loss': 62248.9609375, 'Total Loss': 62256.84765625}
Epoch 2/350, Batch 1, Losses: {'Reconstruction Loss': 3.2083940505981445, 'KL Loss': 60296.9375, 'Total Loss': 60300.14453125}
Epoch 3/350, Batch 1, Losses: {'Reconstruction Loss': 1.7205207347869873, 'KL Loss': 49684.15625, 'Total Loss': 49685.875}
Epoch 4/350, Batch 1, Losses: {'Reconstruction Loss': 0.9303736090660095, 'KL Loss': 29192.275390625, 'Total Loss': 29193.205078125}
Epoch 5/350, Batch 1, Losses: {'Reconstruction Loss': 0.4295249283313751, 'KL Loss': 11482.3916015625, 'Total Loss': 11482.8212890625}
Epoch 6/350, Batch 1, Losses: {'Reconstruction Loss': 0.5216897130012512, 'KL Loss': 10501.798828125, 'Total Loss': 10502.3203125}
Epoch 7/350, Batch 1, Losses: {'Reconstruction Loss': 2.346466064453125, 'KL Loss

26it [02:48, 14.77s/it]

Epoch 347/350, Batch 1, Losses: {'Reconstruction Loss': 0.0518423430621624, 'KL Loss': 0.26366397738456726, 'Total Loss': 0.31550630927085876}
Epoch 348/350, Batch 1, Losses: {'Reconstruction Loss': 0.05149812251329422, 'KL Loss': 0.2607557475566864, 'Total Loss': 0.3122538626194}
Epoch 349/350, Batch 1, Losses: {'Reconstruction Loss': 0.05175979062914848, 'KL Loss': 0.2607909142971039, 'Total Loss': 0.31255069375038147}
Epoch 350/350, Batch 1, Losses: {'Reconstruction Loss': 0.051969531923532486, 'KL Loss': 0.2592069208621979, 'Total Loss': 0.31117644906044006}
Epoch 1/350, Batch 1, Losses: {'Reconstruction Loss': 7.973967552185059, 'KL Loss': 62257.8984375, 'Total Loss': 62265.87109375}
Epoch 2/350, Batch 1, Losses: {'Reconstruction Loss': 27.37358283996582, 'KL Loss': 49337.30078125, 'Total Loss': 49364.67578125}
Epoch 3/350, Batch 1, Losses: {'Reconstruction Loss': 4079.843994140625, 'KL Loss': 95761488.0, 'Total Loss': 95765568.0}
Epoch 4/350, Batch 1, Losses: {'Reconstruction Los

27it [03:08,  6.99s/it]

Epoch 349/350, Batch 1, Losses: {'Reconstruction Loss': nan, 'KL Loss': nan, 'Total Loss': nan}
Epoch 350/350, Batch 1, Losses: {'Reconstruction Loss': nan, 'KL Loss': nan, 'Total Loss': nan}





In [14]:
# Rank every configuration on the individual loss terms only.
# 'Total Loss' and the hyper-parameter columns are deliberately not ranked;
# naming the loss columns explicitly keeps any extra column that is later
# added to results_df from being ranked as if it were a loss.
loss_columns = ['Reconstruction Loss', 'KL Loss']

# na_option='bottom' assigns diverged runs (NaN losses) the worst ranks.
# With the default ('keep') a NaN loss yields a NaN rank that mean(axis=1)
# silently skips, so a run with a single NaN term would be judged on its
# remaining term alone.
ranked_losses = results_df[loss_columns].rank(na_option='bottom')

# Attach the average rank across the loss terms (lower is better) and order
# the configurations by it. assign() returns a new frame, so results_df is
# left untouched and the cell can be re-run safely.
ranked_df = (
    results_df
    .assign(average_rank=ranked_losses.mean(axis=1))
    .sort_values('average_rank')
)

# Display only the top-ranked configurations instead of dumping the full grid;
# ranked_df itself still holds every configuration.
top_n = 10
ranked_df.head(top_n)

Unnamed: 0,Reconstruction Loss,KL Loss,Total Loss,hidden_dim,n_layers,pre_train_learning_rate,average_rank
24,0.050729,2.842956,2.893685,512.0,8.0,0.0001,5.5
16,0.051739,0.69014,0.741879,256.0,8.0,0.001,7.0
13,0.051627,2.068588,2.120214,256.0,4.0,0.001,7.0
25,0.05197,0.259207,0.311176,512.0,8.0,0.001,8.0
21,0.050912,9.27661,9.327522,512.0,4.0,0.0001,8.5
10,0.051512,3.765398,3.816909,256.0,2.0,0.001,8.5
7,0.051636,3.176954,3.22859,128.0,8.0,0.001,9.0
19,0.051809,3.002637,3.054446,512.0,2.0,0.001,9.5
17,0.05131,12.442724,12.494034,256.0,8.0,0.01,9.5
22,0.052277,0.581075,0.633352,512.0,4.0,0.001,10.0


In [23]:
# --- Model dimensions -------------------------------------------------------
data_dim = 100    # dimension of the input data
latent_dim = 70   # dimension of the latent space
hidden_dim = 512
n_layers = 8

# Latent prior: the first `latent_dim` entries of prior_eigenvalues
latent_diagonal = prior_eigenvalues[:latent_dim]

# --- Optimisation settings --------------------------------------------------
batch_size = 1000
# Beta value for the beta-VAE objective.
# NOTE(review): this rebinds the notebook-level name `beta`, which the SABR
# setup cell at the top of the notebook set to 0.50 (elasticity parameter).
# Re-running that SABR code after this cell would pick up 1.0 — consider a
# dedicated name once all downstream uses of `beta` have been checked.
beta = 1.0
pre_train_learning_rate = 0.001
fine_tune_learning_rate = 0.0001
pre_train_epochs = 700
fine_tune_epochs = 20
device = "cpu"    # run on the CPU

# Collect every constructor argument in one mapping, then build the trainer
trainer_settings = {
    "latent_dim": latent_dim,
    "hidden_dim": hidden_dim,
    "n_layers": n_layers,
    "data_dim": data_dim,
    "latent_diagonal": latent_diagonal,
    "batch_size": batch_size,
    "beta": beta,
    "pre_train_learning_rate": pre_train_learning_rate,
    "fine_tune_learning_rate": fine_tune_learning_rate,
    "pre_train_epochs": pre_train_epochs,
    "fine_tune_epochs": fine_tune_epochs,
    "device": device,
}
trainer = SurfaceVAETrainer(**trainer_settings)

# Run the pre-training stage, passing the smoothness prior to the trainer
trainer.pre_train_with_sampling(smoothness_prior=smoothness_prior, experiment_name="test vae")

Epoch 1/700, Batch 1, Losses: {'Reconstruction Loss': 6.7736005783081055, 'KL Loss': 62260.05859375, 'Total Loss': 62266.83203125}
Epoch 2/700, Batch 1, Losses: {'Reconstruction Loss': 2.980165719985962, 'KL Loss': 61058.8125, 'Total Loss': 61061.79296875}
Epoch 3/700, Batch 1, Losses: {'Reconstruction Loss': 1.6296354532241821, 'KL Loss': 53239.84765625, 'Total Loss': 53241.4765625}
Epoch 4/700, Batch 1, Losses: {'Reconstruction Loss': 0.9533772468566895, 'KL Loss': 35942.7265625, 'Total Loss': 35943.6796875}
Epoch 5/700, Batch 1, Losses: {'Reconstruction Loss': 0.44787704944610596, 'KL Loss': 14391.67578125, 'Total Loss': 14392.1240234375}
Epoch 6/700, Batch 1, Losses: {'Reconstruction Loss': 0.5702638626098633, 'KL Loss': 11079.384765625, 'Total Loss': 11079.955078125}
Epoch 7/700, Batch 1, Losses: {'Reconstruction Loss': 1.3588809967041016, 'KL Loss': 16211.6484375, 'Total Loss': 16213.0068359375}
Epoch 8/700, Batch 1, Losses: {'Reconstruction Loss': 2.0987603664398193, 'KL Loss': 

In [24]:
# Pre-training loss history as a DataFrame (one column per loss term)
loss_history = pd.DataFrame(trainer.pre_train_loss_history)

# Panel layout: (loss column, subplot row, subplot column).
# The two individual losses share the first row; the total loss sits in the
# second row, which spans both columns.
loss_panels = [
    ("Reconstruction Loss", 1, 1),
    ("KL Loss", 1, 2),
    ("Total Loss", 2, 1),
]

fig = make_subplots(
    rows=2,
    cols=2,
    subplot_titles=tuple(loss_name for loss_name, _, _ in loss_panels),
    specs=[
        [{"type": "scatter"}, {"type": "scatter"}],
        [{"type": "scatter", "colspan": 2}, None],
    ],
    vertical_spacing=0.1,
    horizontal_spacing=0.1,
)

# For each panel: draw the loss curve, label the x-axis and use a log y-axis
for loss_name, panel_row, panel_col in loss_panels:
    curve = go.Scatter(
        x=loss_history.index,
        y=loss_history[loss_name],
        mode="lines",
        name=loss_name,
    )
    fig.add_trace(curve, row=panel_row, col=panel_col)
    fig.update_xaxes(title_text="Iterations", row=panel_row, col=panel_col)
    fig.update_yaxes(type="log", row=panel_row, col=panel_col)

# Overall figure size and title; the subplot titles already identify each
# curve, so the legend is hidden
fig.update_layout(
    title_text="Beta-VAE Training Losses",
    width=900,
    height=900,
    showlegend=False,
)

# Render the figure
fig.show()