This notebook visualizes the validation loss curves recorded while training the LatentVAE, and compares them across model versions.

In [None]:
import pandas as pd
import matplotlib.pyplot as plt

In [None]:
# Validation-loss log of the LatentVAE v17 run; the next cell plots its
# 'epoch' column against several 'val/*' loss columns.
df = pd.read_csv(filepath_or_buffer='losses/latent_vae_v17_val_losses.csv')

In [None]:
# Perceptual and Reconstruction Losses
# The two reconstruction losses share the left y axis; the perceptual loss
# gets its own right y axis (twinx) so it is plotted on an independent scale.
# Both axes share the epoch x axis.
fig, ax1 = plt.subplots(figsize=(10, 5))

# reconstruction losses on the left axis (solid = image, dashed = latent)
ax1.plot(df['epoch'], df['val/rec_img_loss'], label='Image Reconstruction Loss', color='green')
ax1.plot(df['epoch'], df['val/rec_lat_loss'], label='Latent Reconstruction Loss', color='green', linestyle='--')
ax1.set_xlabel('Epoch')
ax1.set_ylabel('Reconstruction Losses', color='green')
ax1.tick_params(axis='y', colors='green')
ax1.grid(True)

# perceptual loss on the right axis
ax2 = ax1.twinx()
ax2.plot(df['epoch'], df['val/perc_img_loss'], label='Perceptual Loss', color='blue')
ax2.set_ylabel('Perceptual Loss', color='blue')
ax2.tick_params(axis='y', colors='blue')

# Combine both axes' entries into a single legend. Attach it to ax2: the twin
# axes is drawn after ax1, so a legend owned by ax1 can be painted over by
# the perceptual-loss line.
lines1, labels1 = ax1.get_legend_handles_labels()
lines2, labels2 = ax2.get_legend_handles_labels()
ax2.legend(lines1 + lines2, labels1 + labels2, loc='upper right')

# Set the title on an explicit axes; plt.title would target whichever axes is
# current (ax2 after twinx()).
ax1.set_title('Validation Loss Over Epochs')
plt.show()

In [None]:
import math

# Load the validation-loss logs of every model version being compared
# (v17 is re-read here, independently of `df` above). The individual df_v*
# names are kept so other cells can still refer to them.
df_v16 = pd.read_csv('losses/latent_vae_v16_sj_val_losses.csv')
df_v17 = pd.read_csv('losses/latent_vae_v17_val_losses.csv')
df_v18 = pd.read_csv('losses/latent_vae_v18_val_losses.csv')
df_v19 = pd.read_csv('losses/latent_vae_v19_val_losses.csv')
df_v20 = pd.read_csv('losses/latent_vae_v20_val_losses.csv')

# One ordered table of versions so preprocessing and plotting cannot drift
# apart when a version is added or removed; insertion order = legend order.
version_dfs = {
    'v16': df_v16,
    'v17': df_v17,
    'v18': df_v18,
    'v19': df_v19,
    'v20': df_v20,
}

# Preprocess disc loss (multiply with disc_active). This modifies the df_v*
# frames in place. NOTE(review): presumably disc_active is a 0/1 flag that
# zeroes the loss where the discriminator was inactive — confirm.
for frame in version_dfs.values():
    frame['val/disc_loss'] = frame['val/disc_loss'] * frame['val/disc_active']

# List of losses to plot
loss_cols = [
    'val/disc_loss', 'val/gen_loss', 'val/kl_loss', 'val/nll_loss',
    'val/perc_img_loss', 'val/rec_img_loss', 'val/rec_lat_loss', 'val/total_loss'
]

# Create one subplot per loss. squeeze=False always yields a 2-D array, so
# flatten() works for any grid shape (including a single panel).
n_cols = 2
n_rows = math.ceil(len(loss_cols) / n_cols)
fig, axes = plt.subplots(
    n_rows, n_cols,
    figsize=(10 * n_cols, 4 * n_rows),
    sharex=True,
    squeeze=False,
)
axes = axes.flatten()

for ax, loss in zip(axes, loss_cols):
    for version, frame in version_dfs.items():
        ax.plot(frame['step'], frame[loss], label=version)
    ax.grid(True)
    ax.legend(loc='upper right')
    ax.set_title(loss)

# Label the x axis on the lowest plotted panel of *every* column (labelling
# only axes[-1] left the other bottom panels unlabelled), and drop grid cells
# left empty when len(loss_cols) is not a multiple of n_cols.
n_used = len(loss_cols)
for idx, ax in enumerate(axes):
    if idx >= n_used:
        ax.remove()
    elif idx + n_cols >= n_used:
        ax.set_xlabel('Steps')
        # sharex hides tick labels on inner rows; re-enable them on the panel
        # that is now the bottom of its column.
        ax.tick_params(axis='x', labelbottom=True)

plt.tight_layout()
plt.show()