In [None]:
import plotly.graph_objects as go

# -----------------------------
# DATA ORGANIZATION
# -----------------------------
# Property label shown in the figure title.
parameter_name = "heterogeneous graph ZPVE meV"

# Baseline (benchmark) model: Pamnet
baseline_model = "Pamnet (benchmark)"
baseline_mae   = 0.001343
# NOTE(review): the y-axis is labelled [meV]; confirm these MAE values are
# actually in meV (they are compared directly with baseline_mae below).

# Other models and their MAEs. Insertion order is the left-to-right bar order.
# Keys are used verbatim as x-axis labels, so keep them free of stray spaces.
models = {
    "directed connected backbone": 0.002705,
    "undirected connected backbone": 0.002297,
    "fully connected": 0.003228,
    "fully connected + frequency": 0.001452,
    "Faenet fully connected + frequency": 0.001634,
}

# -----------------------------
# COMPUTE DELTAS FROM BASELINE
# -----------------------------
# Sign convention (one formula for every model):
#   delta = baseline_mae - model_mae
#   delta > 0  -> model MAE is LOWER than the baseline (improvement)
#   delta < 0  -> model MAE is HIGHER than the baseline (regression)
model_names = list(models)
model_deltas = [baseline_mae - mae_value for mae_value in models.values()]

# -----------------------------
# ASSIGN COLORS
# -----------------------------
# Green if delta > 0 (model beats the baseline, i.e. lower MAE);
# red otherwise (higher MAE, or an exact tie with the baseline).
colors = ["green" if d > 0 else "red" for d in model_deltas]

# -----------------------------
# BUILD THE PLOTLY BAR CHART
# -----------------------------
fig = go.Figure()

# One bar per model: height = MAE delta from the baseline, colored per `colors`,
# with the numeric delta printed outside each bar.
fig.add_trace(
    go.Bar(
        x=model_names,
        y=model_deltas,
        marker_color=colors,
        text=[f"{delta:.6f}" for delta in model_deltas],  # Show delta up to 6 decimals
        textposition='outside'
    )
)

fig.update_layout(
    title=f"{parameter_name} — MAE Δ from {baseline_model} ({baseline_mae:.6f} 100 Epochs)",
    xaxis_title="Models",
    yaxis_title="MAE Delta from Baseline [meV]",
    # Horizontal line at y=0 represents the baseline. xref='paper' with
    # x0=0, x1=1 makes it span the full plot width regardless of categories.
    shapes=[
        dict(
            type='line',
            xref='paper', x0=0, x1=1,
            yref='y', y0=0, y1=0,
            line=dict(color='black', width=2)
        )
    ],
    # Deliberately small base margins: automargin=True on both axes (below)
    # lets Plotly grow them as needed to fit tick labels and axis titles.
    margin=dict(l=60, r=60, t=80, b=10),
    paper_bgcolor='white',
    plot_bgcolor='white'
)

# Adjust axis settings.
# NOTE: no `tickpad` here; per the original author, older Plotly versions
# do not support it.
fig.update_xaxes(
    ticks='outside',
    tickangle=0,
    tickfont=dict(size=10),
    automargin=True
)

fig.update_yaxes(
    tickfont=dict(size=10),
    automargin=True
)

fig.show()

# Uncomment to save figure as PNG (requires kaleido: pip install kaleido)
# fig.write_image("my_figure.png", width=800, height=600)

