In [25]:
import joblib
import optuna
from plotly.io import show

# Load the saved Optuna study.
# NOTE(review): joblib.load unpickles the file, and unpickling can execute
# arbitrary code -- only load study files that come from a trusted source.
study = joblib.load("./saved_network/optuna_study.pkl")

# Position of each loss inside `trial.values`. This is the same mapping the
# 3D Pareto plots in the later cells use (values[0]=L_inv, values[1]=L_pred,
# values[2]=L_latent).
objective_index = {"L_inv": 0, "L_pred": 1, "L_latent": 2}

# Objectives shown on the 2D plot. The axis labels and the plotted values are
# both derived from this one list so they cannot drift apart (previously the
# axis labelled "L_pred" was actually plotting values[2], i.e. L_latent).
# To plot L_inv against L_latent instead, change this list only.
axes_2d = ["L_inv", "L_pred"]

# Make the basic 2D Pareto front.
fig2d = optuna.visualization.plot_pareto_front(
    study,
    targets=lambda t: tuple(t.values[objective_index[name]] for name in axes_2d),
    target_names=axes_2d,
)

# Trace 0 is the dominated trials; trace 1 is the Pareto front.
# NOTE(review): assumes the figure contains exactly these two traces in this
# order; presumably a study with constraints adds another trace -- verify
# against the installed Optuna version if the study uses constraints.
dominated = fig2d.data[0]
pareto = fig2d.data[1]

# 1) Hide the colorbar for the dominated-trial trace.
dominated.marker.showscale = False
# 2) Force the dominated trials to a single flat color instead of a gradient.
dominated.marker.color = "lightgray"

# (Optional) to tweak the look of the Pareto-front colorbar:
# pareto.marker.colorbar.title = "your title"

fig2d.show()


In [2]:
# 3D Pareto front over all three objectives, in the order they are stored in
# `trial.values`: L_inv, L_pred, L_latent.
# Relies on `study`, `optuna` and `show` defined in the previous cell.
def _three_objectives(trial):
    """Return the first three objective values of an Optuna trial as a tuple."""
    return (trial.values[0], trial.values[1], trial.values[2])


_names_3d = ["L_inv", "L_pred", "L_latent"]

fig3d = optuna.visualization.plot_pareto_front(
    study,
    targets=_three_objectives,
    target_names=_names_3d,
)
show(fig3d)

In [4]:
import joblib
import optuna
import numpy as np
import plotly.graph_objects as go

# Reload the saved Optuna study so this cell can run on its own.
# NOTE(review): joblib.load unpickles the file, and unpickling can execute
# arbitrary code -- only load study files that come from a trusted source.
study = joblib.load("./saved_network/optuna_study.pkl")

# Position of each objective inside `trial.values`.
ix_Linv, ix_Lpred, ix_Llatent = 0, 1, 2

# Keep every trial that actually reported objective values
# (use study.best_trials instead to restrict to the Pareto set).
all_trials = [t for t in study.trials if t.values is not None]
if not all_trials:
    # Without this guard the empty array below is 1-D and `sum(axis=1)` fails
    # with an obscure NumPy axis error.
    raise ValueError(
        "The study contains no trials with objective values; nothing to plot."
    )

# One row per trial, columns ordered (L_inv, L_pred, L_latent).
points = np.array(
    [[t.values[ix_Linv], t.values[ix_Lpred], t.values[ix_Llatent]] for t in all_trials]
)

# Choose the "best" point: here, the trial with the smallest sum of the three
# losses.
# NOTE(review): this assumes every objective is minimized -- confirm against
# study.directions.
best_idx = np.argmin(points.sum(axis=1))
# Alternative: minimize L_inv only
# best_idx = np.argmin(points[:, 0])

best_point = points[best_idx]

# Get the Plotly figure of the 3D Pareto front from Optuna.
fig = optuna.visualization.plot_pareto_front(
    study,
    targets=lambda t: (t.values[ix_Linv], t.values[ix_Lpred], t.values[ix_Llatent]),
    target_names=["L_inv", "L_pred", "L_latent"]
)

# Overlay the selected trial as a single labelled marker.
fig.add_trace(
    go.Scatter3d(
        x=[best_point[0]],
        y=[best_point[1]],
        z=[best_point[2]],
        mode='markers+text',
        marker=dict(size=10, color='black', symbol='diamond-open'),  # 'circle', 'cross', 'x', etc. allowed
        text=["Best"],
        textposition="top center",
        name="Best"
    )
)

fig.show()
