In [1]:
import os

import matplotlib.pyplot as plt
import numpy as np

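Plot the test results for the stock: the L2, JS, and MMD metrics of each model (VAE, Beta-VAE, Info-VAE, WAE) at signature degrees $d = 3, 4, 5$.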

In [2]:
# Axis labels and lookup keys shared by the plotting cells.
# Every per-degree list below holds one value per entry of `metrics`,
# in that order: [L2, JS, MMD].
metrics = ['L2', 'JS', 'MMD']
outputs = ['$d=3$', '$d=4$', '$d=5$']
models = ['VAE', 'Beta-VAE', 'Info-VAE', 'WAE']

# NOTE(review): betavae_d3 has a negative JS entry and betavae_d5 shares its
# MMD value with vae_d5 exactly -- presumably genuine results; confirm against the logs.

# --- VAE ---
vae_d3 = [1.9547979831695557, 0.04277116060256958, 0.3127082884311676]
vae_d4 = [1.8204433917999268, 1.243423581123352, 0.27283793687820435]
vae_d5 = [2.0481173992156982, 2.925053358078003, 0.26442262530326843]

# --- Beta-VAE ---
betavae_d3 = [1.9801794290542603, -0.013278365135192871, 0.3105798661708832]
betavae_d4 = [1.8724522590637207, 2.002256393432617, 0.272980660200119]
betavae_d5 = [2.178828001022339, 4.70975923538208, 0.26442262530326843]

# --- Info-VAE ---
infovae_d3 = [2.082566976547241, 0.3941463232040405, 0.31102630496025085]
infovae_d4 = [1.9158434867858887, 2.6492464542388916, 0.27292272448539734]
infovae_d5 = [2.0393199920654297, 2.9667372703552246, 0.26442378759384155]

# --- WAE ---
wae_d3 = [1.9618104696273804, 0.01708132028579712, 0.31478625535964966]
wae_d4 = [1.8207640647888184, 1.190171241760254, 0.27292928099632263]
wae_d5 = [2.102574110031128, 3.4839935302734375, 0.2644226849079132]

# model name -> [d=3 metrics, d=4 metrics, d=5 metrics], keyed in `models` order
data = dict(zip(models, (
    [vae_d3, vae_d4, vae_d5],
    [betavae_d3, betavae_d4, betavae_d5],
    [infovae_d3, infovae_d4, infovae_d5],
    [wae_d3, wae_d4, wae_d5],
)))

In [3]:
# Create one figure per model: at each signature degree, draw one bar per metric.
sig_dir = 'assets/sig_comparison'
# Create the full output directory once, up front. Creating only "assets" left
# "assets/sig_comparison" missing, so savefig raised FileNotFoundError whenever
# that subdirectory did not already exist.
os.makedirs(sig_dir, exist_ok=True)

bar_width = 0.25
index = np.arange(len(outputs))  # x position of the first bar in each degree group

for model in models:
    fig, ax = plt.subplots(figsize=(8, 5))

    # Plot bars for each metric, shifted right by one bar width per metric
    for i, metric in enumerate(metrics):
        # i-th metric of this model at every signature degree
        metric_values = [per_degree[i] for per_degree in data[model]]
        ax.bar(index + i * bar_width, metric_values, bar_width, label=metric)

    # Customize
    ax.set_xlabel('Signature degree')
    ax.set_ylabel('Metric Values')
    # Title left disabled; re-enable if the figure must stand alone.
    # ax.set_title(f'Comparison of Metrics for {model} Across Outputs')
    # Centre each tick under its group of bars
    ax.set_xticks(index + bar_width * (len(metrics) - 1) / 2)
    ax.set_xticklabels(outputs)
    ax.legend()
    ax.grid(True, linestyle='--', alpha=0.7)

    # Save
    # NOTE(review): "dergee" is a typo for "degree"; the filename is kept as-is
    # so anything already referencing these files keeps working.
    fig.savefig(f'{sig_dir}/{model}_across_dergee.png', dpi=300, bbox_inches='tight')
    plt.close(fig)  # Close this figure explicitly to free memory

In [4]:
# Create one figure per output (signature degree): for each model, draw one bar per metric.
model_dir = 'assets/model_comparison'
# Create the full output directory once, up front. Creating only "assets" left
# "assets/model_comparison" missing, so savefig raised FileNotFoundError whenever
# that subdirectory did not already exist.
os.makedirs(model_dir, exist_ok=True)

bar_width = 0.25
index = np.arange(len(models))  # x position of the first bar in each model group

for i, output in enumerate(outputs):
    fig, ax = plt.subplots(figsize=(10, 6))

    # Plot bars for each metric, shifted right by one bar width per metric
    for j, metric in enumerate(metrics):
        # j-th metric of every model at the i-th signature degree
        metric_values = [data[model][i][j] for model in models]
        ax.bar(index + j * bar_width, metric_values, bar_width, label=metric)

    # Customize
    ax.set_xlabel('Models')
    ax.set_ylabel('Metric Values')
    # Title left disabled; re-enable if the figure must stand alone.
    # ax.set_title(f'Comparison of Metrics Across Models for {output}')
    # Centre each tick under its group of bars
    ax.set_xticks(index + bar_width * (len(metrics) - 1) / 2)
    ax.set_xticklabels(models)
    ax.legend()
    ax.grid(True, linestyle='--', alpha=0.7)

    # Save
    # NOTE(review): the labels in `outputs` are LaTeX strings such as '$d=3$', so
    # the resulting filename is e.g. 'metrics_$d=3$.png' -- the '$' and '='
    # characters are kept here to preserve existing filenames; consider sanitizing.
    fig.savefig(f'{model_dir}/metrics_{output.replace(" ", "_").lower()}.png', dpi=300, bbox_inches='tight')
    plt.close(fig)  # Close this figure explicitly to free memory