In [1]:
from pathlib import Path

from scipy.stats import spearmanr
from sklearn.metrics import (
    mean_absolute_error,
    mean_absolute_percentage_error,
    root_mean_squared_error,
)

In [2]:
import os

import numpy as np
import pandas as pd
import plotly.graph_objs as go

In [3]:
import shg_ml_benchmarks.utils_shg as shg
from shg_ml_benchmarks.utils import load_full

In [4]:
# Load the reference dataset and the Matten predictions to be evaluated.
df_orig = load_full()
df_pred_matten = pd.read_json("df_pred_matten_holdout.json.gz")

# Computing dKP and dRMS from the Matten tensor predictions, both as predicted
# and "masked", i.e. with every component that is exactly zero in the
# reference tensor forced to zero in the prediction as well.
list_dKP_matten = []
list_dRMS_matten = []
list_dKP_matten_masked = []
list_dRMS_matten_masked = []
# Iterate over the single column that is needed rather than building a full
# row Series per entry with iterrows().
for ir, dijk_pred in df_pred_matten["dijk_matten"].items():
    dijk_matten = shg.from_voigt(dijk_pred)
    list_dKP_matten.append(shg.get_dKP(dijk_matten))
    list_dRMS_matten.append(shg.get_dRMS(dijk_matten))

    # One label-based lookup (row, column) instead of the chained
    # `.loc[ir]["dijk_full_neum"]`, which first materialises the whole row.
    dijk_orig = shg.from_voigt(df_orig.loc[ir, "dijk_full_neum"])
    dijk_matten = np.where(dijk_orig != 0, dijk_matten, 0)
    list_dKP_matten_masked.append(shg.get_dKP(dijk_matten))
    list_dRMS_matten_masked.append(shg.get_dRMS(dijk_matten))

df_pred_matten["dKP_matten"] = list_dKP_matten
df_pred_matten["dRMS_matten"] = list_dRMS_matten
df_pred_matten["dKP_matten_masked"] = list_dKP_matten_masked
df_pred_matten["dRMS_matten_masked"] = list_dRMS_matten_masked

# Adding the true dKP / dRMS to the df.
# `.loc` with the prediction index keeps the prediction order and raises a
# KeyError naming any id missing from df_orig, whereas `filter(items, axis=0)`
# silently drops unknown ids and only fails later with a length mismatch.
df_true = df_orig.loc[df_pred_matten.index, ["dKP_full_neum", "dijk_full_neum"]]
df_pred_matten["dKP_true"] = df_true["dKP_full_neum"].tolist()
# NOTE(review): the reference tensor is passed to get_dRMS as stored, while the
# loop above converts it with shg.from_voigt first — confirm both are intended.
df_pred_matten["dRMS_true"] = [
    shg.get_dRMS(d) for d in df_true["dijk_full_neum"].tolist()
]

print(df_pred_matten.shape)
display(df_pred_matten.head())

(250, 7)


Unnamed: 0,dijk_matten,dKP_matten,dRMS_matten,dKP_matten_masked,dRMS_matten_masked,dKP_true,dRMS_true
agm001234439,"[[[2.8334748745000002, 0.030544720600000002, 0...",1.403993,0.780324,1.403993,0.780324,0.827288,0.453301
agm001347560,"[[[5.8048700000000005e-05, 8.793300000000001e-...",189.072237,105.459455,189.072237,105.459455,119.177507,66.474038
agm002041688,"[[[5.156993866, 3.0818479061, -5.3886313438], ...",15.653956,7.101285,15.653956,7.101285,5.595534,2.876049
agm002054632,"[[[-24.1034183502, -3.5733189583, -1.764526367...",24.358271,13.3868,24.358271,13.3868,45.303136,21.07202
agm002072314,"[[[-4.614e-06, -1.298e-07, -0.5090533495], [-1...",0.509487,0.260996,0.509487,0.260996,3.674658,1.610681


In [5]:
# The data-split name is taken from the working-directory name, which must
# contain "predict_"; whatever follows that marker selects the sibling
# "scripts_<type_set>" directory holding the training logs.
cwd_name = Path.cwd().name
if "predict_" not in cwd_name:
    # Fail with an explicit message instead of an opaque IndexError from split().
    raise RuntimeError(
        "This notebook must be run from a 'predict_<type_set>' directory, "
        f"but the current directory is {cwd_name!r}."
    )
type_set = cwd_name.split("predict_")[1]

# Per-epoch metrics logged under lightning_logs/version_0 of that directory.
df_lr = pd.read_csv(f"../scripts_{type_set}/lightning_logs/version_0/metrics.csv")
print(df_lr.shape)
display(df_lr.head())

(1001, 13)


Unnamed: 0,cumulative time,epoch,epoch time,metric_test/MeanAbsoluteError/shg_tensor_full,metric_val/MeanAbsoluteError/shg_tensor_full,step,test/loss/shg_tensor_full,test/total_loss,train/loss/shg_tensor_full,train/total_loss,val/loss/shg_tensor_full,val/score,val/total_loss
0,48.687542,0,12.907778,,5.426843,74,,,,,600.077026,5.426843,600.077026
1,,0,,,,74,,,709.328308,709.328308,,,
2,61.04554,1,12.357997,,5.269061,149,,,,,531.813293,5.269061,531.813293
3,,1,,,,149,,,627.722534,627.722534,,,
4,73.375443,2,12.329903,,5.120953,224,,,,,495.29425,5.120953,495.29425


In [6]:
# The metrics log stores each epoch on separate rows: one row carries the
# validation metrics, another carries the training loss. Select the rows by
# which metric is actually filled in rather than by their position among the
# duplicated epoch values, so that any additional row sharing an epoch number
# cannot end up in the wrong frame.
has_val = df_lr["val/loss/shg_tensor_full"].notna()
has_train = df_lr["train/loss/shg_tensor_full"].notna()

# Validation rows, indexed by epoch (the "epoch" column is kept as well).
df_lr_epochs_first = df_lr[has_val].set_index("epoch", drop=False)
display(df_lr_epochs_first.head())
# Training rows, indexed by epoch (the "epoch" column is kept as well).
df_lr_epochs_second = df_lr[has_train].set_index("epoch", drop=False)
display(df_lr_epochs_second.head())

Unnamed: 0_level_0,cumulative time,epoch,epoch time,metric_test/MeanAbsoluteError/shg_tensor_full,metric_val/MeanAbsoluteError/shg_tensor_full,step,test/loss/shg_tensor_full,test/total_loss,train/loss/shg_tensor_full,train/total_loss,val/loss/shg_tensor_full,val/score,val/total_loss
epoch,Unnamed: 1_level_1,Unnamed: 2_level_1,Unnamed: 3_level_1,Unnamed: 4_level_1,Unnamed: 5_level_1,Unnamed: 6_level_1,Unnamed: 7_level_1,Unnamed: 8_level_1,Unnamed: 9_level_1,Unnamed: 10_level_1,Unnamed: 11_level_1,Unnamed: 12_level_1,Unnamed: 13_level_1
0,48.687542,0,12.907778,,5.426843,74,,,,,600.077026,5.426843,600.077026
1,61.04554,1,12.357997,,5.269061,149,,,,,531.813293,5.269061,531.813293
2,73.375443,2,12.329903,,5.120953,224,,,,,495.29425,5.120953,495.29425
3,85.751953,3,12.376513,,5.183712,299,,,,,554.202576,5.183712,554.202576
4,98.145844,4,12.393886,,5.212463,374,,,,,591.223206,5.212463,591.223206


Unnamed: 0_level_0,cumulative time,epoch,epoch time,metric_test/MeanAbsoluteError/shg_tensor_full,metric_val/MeanAbsoluteError/shg_tensor_full,step,test/loss/shg_tensor_full,test/total_loss,train/loss/shg_tensor_full,train/total_loss,val/loss/shg_tensor_full,val/score,val/total_loss
epoch,Unnamed: 1_level_1,Unnamed: 2_level_1,Unnamed: 3_level_1,Unnamed: 4_level_1,Unnamed: 5_level_1,Unnamed: 6_level_1,Unnamed: 7_level_1,Unnamed: 8_level_1,Unnamed: 9_level_1,Unnamed: 10_level_1,Unnamed: 11_level_1,Unnamed: 12_level_1,Unnamed: 13_level_1
0,,0,,,,74,,,709.328308,709.328308,,,
1,,1,,,,149,,,627.722534,627.722534,,,
2,,2,,,,224,,,504.990906,504.990906,,,
3,,3,,,,299,,,359.805542,359.805542,,,
4,,4,,,,374,,,286.304962,286.304962,,,


# Learning curves

## Loss

In [7]:
# Loss per epoch: validation rows come from df_lr_epochs_first, training rows
# from df_lr_epochs_second.
_curve_style = {"mode": "lines+markers", "showlegend": True}

lr_val = go.Scatter(
    x=df_lr_epochs_first.index,
    y=df_lr_epochs_first["val/loss/shg_tensor_full"],
    name="Validation loss",
    **_curve_style,
)
lr_train = go.Scatter(
    x=df_lr_epochs_second.index,
    y=df_lr_epochs_second["train/loss/shg_tensor_full"],
    name="Train loss",
    **_curve_style,
)

# Only the x axis gets a title; the y axis is left untitled.
layout = go.Layout(xaxis={"title": "Epoch"})

# Training curve is drawn first, validation curve on top of it.
fig = go.Figure(data=[lr_train, lr_val], layout=layout)

# Global look (auto-sized, large font, white template) and no grid lines.
fig.update_layout(
    autosize=True,
    font_size=20,
    template="simple_white",
    xaxis={"showgrid": False},
    yaxis={"showgrid": False},
)

fig.show()

## MAE

In [8]:
# Validation MAE per epoch, read from the validation rows of the metrics log.
lr_val = go.Scatter(
    x=df_lr_epochs_first.index,
    y=df_lr_epochs_first["metric_val/MeanAbsoluteError/shg_tensor_full"],
    mode="lines+markers",
    # The legend entry must describe the plotted quantity: this trace is the
    # validation MAE, not the validation loss.
    name="Validation MAE",
    showlegend=True,
)

# Layout: axis titles only.
layout = go.Layout(
    xaxis=dict(title="Epoch"),
    yaxis=dict(title="MAE (pm/V)"),
)

# Create figure
fig = go.Figure(data=[lr_val], layout=layout)

# Global look (auto-sized, large font, white template) and no grid lines.
fig.update_layout(
    autosize=True,
    font_size=20,
    template="simple_white",
)
fig.update_layout(
    xaxis=dict(showgrid=False),
    yaxis=dict(showgrid=False),
)

# Show figure
fig.show()

# Errors on dijk

In [9]:
# Mean, over the predicted entries, of np.linalg.norm(reference - prediction).
# Entries whose reference tensor has no non-zero component are skipped.
list_fronorm = []  # one value per kept entry: norm of the difference (no normalisation)
for entry_id, pred_voigt in df_pred_matten["dijk_matten"].items():
    ref_tensor = shg.from_voigt(df_orig.loc[entry_id, "dijk_full_neum"])
    pred_tensor = shg.from_voigt(pred_voigt)
    if not np.count_nonzero(ref_tensor):
        continue
    list_fronorm.append(np.linalg.norm(ref_tensor - pred_tensor))
print(np.mean(list_fronorm))

32.56678429175125


In [10]:
# Same error as the previous cell, but each norm is divided by the number of
# non-zero components of the reference tensor before averaging.
# Entries whose reference tensor has no non-zero component are skipped.
list_fronorm = []  # one value per kept entry: norm(diff) / nb_nonzero(reference)
for entry_id, pred_voigt in df_pred_matten["dijk_matten"].items():
    ref_tensor = shg.from_voigt(df_orig.loc[entry_id, "dijk_full_neum"])
    pred_tensor = shg.from_voigt(pred_voigt)
    n_nonzero = np.count_nonzero(ref_tensor)
    if n_nonzero == 0:
        continue
    list_fronorm.append(np.linalg.norm(ref_tensor - pred_tensor) / n_nonzero)
print(np.mean(list_fronorm))

3.8486072897105013


## Errors on dijk masked

In [11]:
# Masked variant: predicted components are zeroed wherever the reference
# component is exactly zero, then the norm of the difference is averaged.
# Entries whose reference tensor has no non-zero component are skipped.
list_fronorm = []  # one value per kept entry: norm of the masked difference
for entry_id, pred_voigt in df_pred_matten["dijk_matten"].items():
    ref_tensor = shg.from_voigt(df_orig.loc[entry_id, "dijk_full_neum"])
    pred_tensor = shg.from_voigt(pred_voigt)
    if not np.count_nonzero(ref_tensor):
        continue
    pred_masked = np.where(ref_tensor != 0, pred_tensor, 0)
    list_fronorm.append(np.linalg.norm(ref_tensor - pred_masked))
print(np.mean(list_fronorm))

30.9475236323985


In [12]:
# Masked variant normalised by the number of non-zero reference components:
# predicted components are zeroed wherever the reference component is exactly
# zero, and each norm is divided by the reference's non-zero count.
# Entries whose reference tensor has no non-zero component are skipped.
list_fronorm = []  # = sqrt(sum(sqrd))/nb_nonzero
for ir, r in df_pred_matten.iterrows():
    # Single (row, column) lookup instead of the chained `.loc[ir][col]`.
    dijk_orig = shg.from_voigt(df_orig.loc[ir, "dijk_full_neum"])
    dijk_matten = shg.from_voigt(r["dijk_matten"])
    dijk_matten = np.where(dijk_orig != 0, dijk_matten, 0)
    if np.count_nonzero(dijk_orig) == 0:
        continue
    # The difference is used once, directly in the norm (the unused temporary
    # that previously duplicated it has been removed).
    list_fronorm.append(
        np.linalg.norm(dijk_orig - dijk_matten) / np.count_nonzero(dijk_orig)
    )
print(np.mean(list_fronorm))

3.5892438538033415


# Errors on dKP

In [13]:
# Error metrics between the reference dKP and the dKP derived from the
# predicted tensors.
# NOTE(review): MAPE is printed exactly as returned by scikit-learn (believed
# to be a fraction rather than a percentage) — confirm before quoting it.
_dkp_ref = df_pred_matten["dKP_true"]
_dkp_hat = df_pred_matten["dKP_matten"]

mae = mean_absolute_error(_dkp_ref, _dkp_hat)
mape = mean_absolute_percentage_error(_dkp_ref, _dkp_hat)
rmse = root_mean_squared_error(_dkp_ref, _dkp_hat)
spearmanrho = spearmanr(_dkp_ref, _dkp_hat)
for _label, _value in (
    ("MAE", mae),
    ("MAPE", mape),
    ("RMSE", rmse),
    ("Rho_sp", spearmanrho.statistic),
):
    print(f"{_label} = {_value}")

# Parity scatter: reference on x, prediction on y, index label attached to
# every marker as its text.
scatter_dKP = go.Scatter(
    x=_dkp_ref,
    y=_dkp_hat,
    mode="markers",
    name="",
    showlegend=False,
    text=list(df_pred_matten.index.values),
)

# Dotted y = x guide line.
_diag = [-1, 200]
ideal = go.Scatter(
    x=_diag,
    y=_diag,
    mode="lines",
    line={"color": "gray", "dash": "dot"},
    showlegend=False,
)

# Both axes share the same fixed range so the guide line is the diagonal.
layout = go.Layout(
    xaxis={"title": "<i>d</i><sub>KP</sub> (pm/V)", "range": [-1, 150]},
    yaxis={"title": "<i>d&#770;</i><sub>KP</sub> (pm/V)", "range": [-1, 150]},
)

fig = go.Figure(data=[scatter_dKP, ideal], layout=layout)

# Fixed square canvas, large font, white template, no grid lines.
fig.update_layout(
    autosize=False,
    font_size=20,
    width=600,
    height=600,
    template="simple_white",
    xaxis={"showgrid": False},
    yaxis={"showgrid": False},
)

fig.show()

MAE = 8.154118007629796
MAPE = 7.2082266636523045
RMSE = 20.61773030443456
Rho_sp = 0.7870278884462151


## Errors on dKP masked

In [14]:
# Error metrics between the reference dKP and the dKP derived from the masked
# predicted tensors.
# NOTE(review): MAPE is printed exactly as returned by scikit-learn (believed
# to be a fraction rather than a percentage) — confirm before quoting it.
_dkp_ref = df_pred_matten["dKP_true"]
_dkp_hat_masked = df_pred_matten["dKP_matten_masked"]

mae = mean_absolute_error(_dkp_ref, _dkp_hat_masked)
mape = mean_absolute_percentage_error(_dkp_ref, _dkp_hat_masked)
rmse = root_mean_squared_error(_dkp_ref, _dkp_hat_masked)
spearmanrho = spearmanr(_dkp_ref, _dkp_hat_masked)
for _label, _value in (
    ("MAE", mae),
    ("MAPE", mape),
    ("RMSE", rmse),
    ("Rho_sp", spearmanrho.statistic),
):
    print(f"{_label} = {_value}")

# Parity scatter: reference on x, masked prediction on y, index label attached
# to every marker as its text.
scatter_dKP = go.Scatter(
    x=_dkp_ref,
    y=_dkp_hat_masked,
    mode="markers",
    name="",
    showlegend=False,
    text=list(df_pred_matten.index.values),
)

# Dotted y = x guide line.
_diag = [-1, 200]
ideal = go.Scatter(
    x=_diag,
    y=_diag,
    mode="lines",
    line={"color": "gray", "dash": "dot"},
    showlegend=False,
)

# Both axes share the same fixed range so the guide line is the diagonal.
layout = go.Layout(
    xaxis={"title": "<i>d</i><sub>KP</sub> (pm/V)", "range": [-1, 150]},
    yaxis={"title": "<i>d&#770;</i><sub>KP</sub> (pm/V)", "range": [-1, 150]},
)

fig = go.Figure(data=[scatter_dKP, ideal], layout=layout)

# Fixed square canvas, large font, white template, no grid lines.
fig.update_layout(
    autosize=False,
    font_size=20,
    width=600,
    height=600,
    template="simple_white",
    xaxis={"showgrid": False},
    yaxis={"showgrid": False},
)

fig.show()

MAE = 8.712632685840285
MAPE = 7.137269740478439
RMSE = 21.742057242309183
Rho_sp = 0.6331020976335621


# Errors on dRMS

In [15]:
# Error metrics between the reference dRMS and the dRMS derived from the
# predicted tensors.
# NOTE(review): MAPE is printed exactly as returned by scikit-learn (believed
# to be a fraction rather than a percentage) — confirm before quoting it.
_drms_ref = df_pred_matten["dRMS_true"]
_drms_hat = df_pred_matten["dRMS_matten"]

mae = mean_absolute_error(_drms_ref, _drms_hat)
mape = mean_absolute_percentage_error(_drms_ref, _drms_hat)
rmse = root_mean_squared_error(_drms_ref, _drms_hat)
spearmanrho = spearmanr(_drms_ref, _drms_hat)
for _label, _value in (
    ("MAE", mae),
    ("MAPE", mape),
    ("RMSE", rmse),
    ("Rho_sp", spearmanrho.statistic),
):
    print(f"{_label} = {_value}")

# Parity scatter: reference on x, prediction on y, index label attached to
# every marker as its text.
scatter_dRMS = go.Scatter(
    x=_drms_ref,
    y=_drms_hat,
    mode="markers",
    name="",
    showlegend=False,
    text=list(df_pred_matten.index.values),
)

# Dotted y = x guide line.
_diag = [-1, 200]
ideal = go.Scatter(
    x=_diag,
    y=_diag,
    mode="lines",
    line={"color": "gray", "dash": "dot"},
    showlegend=False,
)

# Both axes share the same fixed range so the guide line is the diagonal.
layout = go.Layout(
    xaxis={"title": "<i>d</i><sub>RMS</sub> (pm/V)", "range": [-1, 100]},
    yaxis={"title": "<i>d&#770;</i><sub>RMS</sub> (pm/V)", "range": [-1, 100]},
)

fig = go.Figure(data=[scatter_dRMS, ideal], layout=layout)

# Fixed square canvas, large font, white template, no grid lines.
fig.update_layout(
    autosize=False,
    font_size=20,
    width=600,
    height=600,
    template="simple_white",
    xaxis={"showgrid": False},
    yaxis={"showgrid": False},
)

fig.show()

MAE = 4.290890359798537
MAPE = 7.070167103524031
RMSE = 10.783131053594627
Rho_sp = 0.7906882350117601
