In [1]:
import json

import numpy as np
import pandas as pd
import plotly.graph_objs as go
from sklearn.metrics import r2_score

from shg_ml_benchmarks.utils import BENCHMARKS_DIR, load_holdout

# Load data

In [2]:
# Which model run and benchmark task to analyse.
model_label = "modnet"
model_tags = "mmf_pgnn_"  # prefix of the results filename for this run
task = "distribution_125"

# Results live at <BENCHMARKS_DIR>/<model_label>/tasks/<task>/<model_tags>results.json
results_fname = f"{model_tags}results.json"
results_path = BENCHMARKS_DIR / model_label / "tasks" / task / results_fname

# Read the JSON with an explicit encoding so the result does not depend on the
# platform's locale default.
with open(results_path, "r", encoding="utf-8") as f:
    data = json.load(f)

# "metrics" is split off from the payload; the remaining entries become the
# columns of the frame.
metrics = data.pop("metrics")
df_pred_unc = pd.DataFrame(data=data)
print(df_pred_unc.shape)
display(df_pred_unc.head())

(125, 2)


Unnamed: 0,predictions,uncertainties
mp-552663,1.35848,0.589436
mp-753401,1.820289,0.876528
mp-23363,0.172901,0.112393
mp-559961,2.6371,0.809798
mp-17066,1.264187,0.604361


In [3]:
# Restrict the holdout set to the entries that have a prediction.
holdout_df = load_holdout(task).filter(df_pred_unc.index, axis=0)

# DataFrame.filter silently drops labels it cannot find, so verify that every
# predicted id is present; otherwise later cells that pair the two frames
# row-by-row would compare mismatched entries.
missing_ids = df_pred_unc.index.difference(holdout_df.index)
assert missing_ids.empty, (
    f"{len(missing_ids)} predicted ids are absent from the holdout set, "
    f"e.g. {missing_ids.tolist()[:5]}"
)

# Parity plot - log

In [4]:
# Pair observed and predicted values by index label (not by row position) so
# the two lists stay aligned even if the frames differ in row order.
true_values = holdout_df["dKP_full_neum"].tolist()
pred_values = df_pred_unc.loc[holdout_df.index, "predictions"].tolist()

# Natural-log transform, computed once and reused by the scatter trace.
# Note: np.log yields -inf/NaN for non-positive inputs.
log_true_values = np.log(true_values)
log_pred_values = np.log(pred_values)

# One marker per holdout entry; the index label is attached as the point text.
scatter_plot = go.Scatter(
    x=log_true_values,
    y=log_pred_values,
    mode="markers",
    name="",
    showlegend=False,
    text=holdout_df.index.tolist(),
)

# Dotted y = x reference line; it extends beyond the visible axis range below.
ideal = go.Scatter(
    x=[-10, 8],
    y=[-10, 8],
    mode="lines",
    line=dict(color="gray", dash="dot"),
    showlegend=False,
)

# Both axes share the same range and tick settings so the identity line runs
# corner to corner of a square plot.
axis_range = [-8.5, 5.2]
axis_ticks = dict(tickmode="linear", tick0=0, dtick=2, showgrid=False)

layout = go.Layout(
    xaxis=dict(
        title="ln(<i>d</i><sub>KP</sub>) (pm/V)",
        range=axis_range,
        **axis_ticks,
    ),
    yaxis=dict(
        title="ln(<i>d&#770;</i><sub>KP</sub>) (pm/V)",
        range=axis_range,
        **axis_ticks,
    ),
)

fig = go.Figure(data=[scatter_plot, ideal], layout=layout)

# Fixed square canvas with the "simple_white" template.
fig.update_layout(
    autosize=False,
    font_size=20,
    width=600,
    height=600,
    template="simple_white",
)

fig.show()

In [5]:
# Mean absolute error between observed and predicted values in natural-log space.
np.abs(np.log(true_values) - np.log(pred_values)).mean()

0.8118740952271818

In [6]:
# Coefficient of determination (R^2) computed on the ln-transformed values.
r2_score(y_true=np.log(true_values), y_pred=np.log(pred_values))

0.738924405977687

In [8]:
# Coefficient of determination (R^2) on the untransformed values.
r2_score(y_true=true_values, y_pred=pred_values)

0.8073915856579492