In [1]:
from pathlib import Path

from scipy.stats import spearmanr
from sklearn.metrics import (
    mean_absolute_error,
    mean_absolute_percentage_error,
    root_mean_squared_error,
)

In [2]:
import os

import numpy as np
import pandas as pd
import plotly.graph_objs as go

In [3]:
import shg_ml_benchmarks.utils_shg as shg
from shg_ml_benchmarks.utils import load_full

In [4]:
df_orig = load_full()
df_pred_matten = pd.read_json("df_pred_matten_holdout.json.gz")

# Every held-out prediction needs a reference entry. Fail loudly instead of
# letting a missing id silently drop a row and shift the columns below.
missing_ids = df_pred_matten.index.difference(df_orig.index)
assert missing_ids.empty, f"No reference tensor for: {missing_ids.tolist()}"

# Computing dKP / dRMS from the MatTen tensor predictions, both raw and
# "masked" (predicted components zeroed wherever the reference is zero).
list_dKP_matten = []
list_dRMS_matten = []
list_dKP_matten_masked = []
list_dRMS_matten_masked = []
list_dRMS_true = []
for ir, r in df_pred_matten.iterrows():
    dijk_orig = shg.from_voigt(df_orig.at[ir, "dijk_full_neum"])
    dijk_matten = shg.from_voigt(r["dijk_matten"])
    list_dKP_matten.append(shg.get_dKP(dijk_matten))
    list_dRMS_matten.append(shg.get_dRMS(dijk_matten))

    dijk_matten_masked = np.where(dijk_orig != 0, dijk_matten, 0)
    list_dKP_matten_masked.append(shg.get_dKP(dijk_matten_masked))
    list_dRMS_matten_masked.append(shg.get_dRMS(dijk_matten_masked))

    # The reference goes through the same `from_voigt` conversion as the
    # prediction, so true and predicted dRMS use the same tensor representation.
    list_dRMS_true.append(shg.get_dRMS(dijk_orig))

df_pred_matten["dKP_matten"] = list_dKP_matten
df_pred_matten["dRMS_matten"] = list_dRMS_matten
df_pred_matten["dKP_matten_masked"] = list_dKP_matten_masked
df_pred_matten["dRMS_matten_masked"] = list_dRMS_matten_masked

# Adding the true dKP / dRMS to the df. `.loc` returns rows in
# df_pred_matten's index order, so the values line up row by row.
df_pred_matten["dKP_true"] = df_orig.loc[
    df_pred_matten.index, "dKP_full_neum"
].to_numpy()
df_pred_matten["dRMS_true"] = list_dRMS_true

print(df_pred_matten.shape)
display(df_pred_matten.head())

(125, 7)


Unnamed: 0,dijk_matten,dKP_matten,dRMS_matten,dKP_matten_masked,dRMS_matten_masked,dKP_true,dRMS_true
agm001266143,"[[[-1.55e-08, 36.0535964966, 3.4850635529], [3...",35.160685,18.162448,35.160685,18.162448,43.968531,18.641918
agm002018585,"[[[20.1762714386, -6.65e-08, 0.2993813157], [-...",27.494491,12.652972,27.494491,12.652972,54.691229,27.523965
agm002056302,"[[[-1.999e-07, 9.37e-08, 2.503e-07], [9.37e-08...",7.24955,4.043606,7.24955,4.043606,3.414247,1.904376
agm002072309,"[[[-2.8608e-06, -3.389e-07, 3.040938139], [-3....",4.658129,1.908743,4.658129,1.908743,4.324187,1.894918
agm002073755,"[[[0.0534379818, -0.9033260345, -0.4584548771]...",0.863475,0.433968,0.863475,0.433968,0.778823,0.412738


In [5]:
# The notebook is expected to run from a "predict_<type_set>" directory; the
# matching training logs live in the sibling "scripts_<type_set>" directory.
cwd_parts = Path.cwd().name.split("predict_")
if len(cwd_parts) < 2:
    raise RuntimeError(
        f"Expected to run from a 'predict_<type_set>' directory, got {Path.cwd()}"
    )
type_set = cwd_parts[1]

# Lightning run whose metrics.csv is read.
LIGHTNING_VERSION = 0

df_lr = pd.read_csv(
    f"../scripts_{type_set}/lightning_logs/version_{LIGHTNING_VERSION}/metrics.csv"
)
print(df_lr.shape)
display(df_lr.head())

(3001, 13)


Unnamed: 0,cumulative time,epoch,epoch time,metric_test/MeanAbsoluteError/shg_tensor_full,metric_val/MeanAbsoluteError/shg_tensor_full,step,test/loss/shg_tensor_full,test/total_loss,train/loss/shg_tensor_full,train/total_loss,val/loss/shg_tensor_full,val/score,val/total_loss
0,50.552261,0,25.883514,,5.364007,74,,,,,705.038757,5.364007,705.038757
1,,0,,,,74,,,702.674683,702.674683,,,
2,76.574265,1,26.022001,,6.227285,149,,,,,1178.956665,6.227285,1178.956665
3,,1,,,,149,,,674.062073,674.062073,,,
4,102.837662,2,26.263401,,5.937873,224,,,,,779.087524,5.937873,779.087524


In [6]:
# metrics.csv stores validation and training metrics on separate rows for the
# same epoch. Select each set by the metric it carries, not by row order. Any
# extra row for an epoch (e.g. a test-metrics row; this log has test columns)
# can then neither leak into either frame nor duplicate the epoch index.
df_lr_epochs_first = df_lr.dropna(subset=["val/loss/shg_tensor_full"]).set_index(
    "epoch", drop=False
)
display(df_lr_epochs_first.head())
df_lr_epochs_second = df_lr.dropna(subset=["train/loss/shg_tensor_full"]).set_index(
    "epoch", drop=False
)
display(df_lr_epochs_second.head())

Unnamed: 0_level_0,cumulative time,epoch,epoch time,metric_test/MeanAbsoluteError/shg_tensor_full,metric_val/MeanAbsoluteError/shg_tensor_full,step,test/loss/shg_tensor_full,test/total_loss,train/loss/shg_tensor_full,train/total_loss,val/loss/shg_tensor_full,val/score,val/total_loss
epoch,Unnamed: 1_level_1,Unnamed: 2_level_1,Unnamed: 3_level_1,Unnamed: 4_level_1,Unnamed: 5_level_1,Unnamed: 6_level_1,Unnamed: 7_level_1,Unnamed: 8_level_1,Unnamed: 9_level_1,Unnamed: 10_level_1,Unnamed: 11_level_1,Unnamed: 12_level_1,Unnamed: 13_level_1
0,50.552261,0,25.883514,,5.364007,74,,,,,705.038757,5.364007,705.038757
1,76.574265,1,26.022001,,6.227285,149,,,,,1178.956665,6.227285,1178.956665
2,102.837662,2,26.263401,,5.937873,224,,,,,779.087524,5.937873,779.087524
3,128.968491,3,26.130823,,10.378119,299,,,,,5840.25293,10.378119,5840.25293
4,155.126785,4,26.158298,,5.534164,374,,,,,890.255615,5.534164,890.255615


Unnamed: 0_level_0,cumulative time,epoch,epoch time,metric_test/MeanAbsoluteError/shg_tensor_full,metric_val/MeanAbsoluteError/shg_tensor_full,step,test/loss/shg_tensor_full,test/total_loss,train/loss/shg_tensor_full,train/total_loss,val/loss/shg_tensor_full,val/score,val/total_loss
epoch,Unnamed: 1_level_1,Unnamed: 2_level_1,Unnamed: 3_level_1,Unnamed: 4_level_1,Unnamed: 5_level_1,Unnamed: 6_level_1,Unnamed: 7_level_1,Unnamed: 8_level_1,Unnamed: 9_level_1,Unnamed: 10_level_1,Unnamed: 11_level_1,Unnamed: 12_level_1,Unnamed: 13_level_1
0,,0,,,,74,,,702.674683,702.674683,,,
1,,1,,,,149,,,674.062073,674.062073,,,
2,,2,,,,224,,,632.019348,632.019348,,,
3,,3,,,,299,,,510.259064,510.259064,,,
4,,4,,,,374,,,423.123383,423.123383,,,


# Learning curves

## Loss

In [7]:
# Training and validation loss per epoch.
lr_val = go.Scatter(
    x=df_lr_epochs_first.index,
    y=df_lr_epochs_first["val/loss/shg_tensor_full"],
    mode="lines+markers",
    name="Validation loss",
    showlegend=True,
)
lr_train = go.Scatter(
    x=df_lr_epochs_second.index,
    y=df_lr_epochs_second["train/loss/shg_tensor_full"],
    mode="lines+markers",
    name="Train loss",
    showlegend=True,
)

layout = go.Layout(
    xaxis=dict(title="Epoch", showgrid=False),
    # The figure previously had no y-axis label.
    yaxis=dict(title="Loss", showgrid=False),
)

fig = go.Figure(data=[lr_train, lr_val], layout=layout)
fig.update_layout(
    autosize=True,
    font_size=20,
    template="simple_white",
)
fig.show()

## Validation MAE

In [8]:
# Validation mean absolute error of the tensor prediction per epoch.
lr_val = go.Scatter(
    x=df_lr_epochs_first.index,
    y=df_lr_epochs_first["metric_val/MeanAbsoluteError/shg_tensor_full"],
    mode="lines+markers",
    # This trace is the MAE metric, not the loss.
    name="Validation MAE",
    showlegend=True,
)

layout = go.Layout(
    xaxis=dict(title="Epoch", showgrid=False),
    yaxis=dict(title="MAE (pm/V)", showgrid=False),
)

fig = go.Figure(data=[lr_val], layout=layout)
fig.update_layout(
    autosize=True,
    font_size=20,
    template="simple_white",
)
fig.show()

# Errors on dijk

In [9]:
# Mean over structures of the Frobenius norm of the tensor error,
# sqrt(sum of squared component errors); no normalisation by nb_nonzero here.
# Structures whose reference tensor is identically zero are skipped.
list_fronorm = []
for ir, r in df_pred_matten.iterrows():
    dijk_orig = shg.from_voigt(df_orig.loc[ir]["dijk_full_neum"])
    if np.count_nonzero(dijk_orig) == 0:
        continue
    dijk_matten = shg.from_voigt(r["dijk_matten"])
    list_fronorm.append(np.linalg.norm(dijk_orig - dijk_matten))
print(np.mean(list_fronorm))

26.186564379553225


In [10]:
# Same error as above, divided by the number of non-zero components of the
# reference tensor: sqrt(sum of squared errors) / nb_nonzero.
list_fronorm = []
for mpid, row in df_pred_matten.iterrows():
    d_ref = shg.from_voigt(df_orig.loc[mpid]["dijk_full_neum"])
    d_pred = shg.from_voigt(row["dijk_matten"])
    n_nonzero = np.count_nonzero(d_ref)
    if n_nonzero:
        list_fronorm.append(np.linalg.norm(d_ref - d_pred) / n_nonzero)
print(np.mean(list_fronorm))

3.00527538248761


## Errors on dijk, masked (predicted components set to zero wherever the reference component is zero)

In [11]:
# masked: predicted components are zeroed wherever the reference component is
# zero, then sqrt(sum of squared errors) is averaged over structures.
list_fronorm = []
for mpid, row in df_pred_matten.iterrows():
    d_ref = shg.from_voigt(df_orig.loc[mpid]["dijk_full_neum"])
    d_pred = np.where(d_ref != 0, shg.from_voigt(row["dijk_matten"]), 0)
    if not np.count_nonzero(d_ref):
        continue
    list_fronorm.append(np.linalg.norm(d_ref - d_pred))
print(np.mean(list_fronorm))

26.077471728967023


In [12]:
# Masked and normalised: sqrt(sum of squared errors) / nb_nonzero, with the
# prediction zeroed wherever the reference component is zero.
list_fronorm = []
for ir, r in df_pred_matten.iterrows():
    dijk_orig = shg.from_voigt(df_orig.loc[ir]["dijk_full_neum"])
    nb_nonzero = np.count_nonzero(dijk_orig)
    if nb_nonzero == 0:
        continue
    dijk_matten = np.where(dijk_orig != 0, shg.from_voigt(r["dijk_matten"]), 0)
    list_fronorm.append(np.linalg.norm(dijk_orig - dijk_matten) / nb_nonzero)
print(np.mean(list_fronorm))

2.9884827146794555


# Errors on dKP

In [13]:
# Regression metrics and parity plot: true vs. predicted dKP (unmasked).
y_true = df_pred_matten["dKP_true"]
y_pred = df_pred_matten["dKP_matten"]

mae = mean_absolute_error(y_true, y_pred)
# sklearn returns MAPE as a fraction, not as a percentage.
mape = mean_absolute_percentage_error(y_true, y_pred)
rmse = root_mean_squared_error(y_true, y_pred)
spearmanrho = spearmanr(y_true, y_pred)
for label, value in [
    ("MAE", mae),
    ("MAPE", mape),
    ("RMSE", rmse),
    ("Rho_sp", spearmanrho.statistic),
]:
    print(f"{label} = {value}")

# One marker per structure; hovering shows its id.
scatter_dKP = go.Scatter(
    x=y_true,
    y=y_pred,
    mode="markers",
    name="",
    showlegend=False,
    text=df_pred_matten.index.tolist(),
)
# y = x reference line, drawn past the visible range.
ideal = go.Scatter(
    x=[-1, 200],
    y=[-1, 200],
    mode="lines",
    line=dict(color="gray", dash="dot"),
    showlegend=False,
)

layout = go.Layout(
    xaxis=dict(
        title="<i>d</i><sub>KP</sub> (pm/V)", range=[-1, 170], showgrid=False
    ),
    yaxis=dict(
        title="<i>d&#770;</i><sub>KP</sub> (pm/V)", range=[-1, 170], showgrid=False
    ),
)

fig = go.Figure(data=[scatter_dKP, ideal], layout=layout)
fig.update_layout(
    autosize=False,
    font_size=20,
    width=600,
    height=600,
    template="simple_white",
)
fig.show()

MAE = 6.901184334759853
MAPE = 5.38744431741467
RMSE = 17.61601006369132
Rho_sp = 0.8044731182795697


## Errors on dKP, masked (predicted components set to zero wherever the reference component is zero)

In [14]:
# Regression metrics and parity plot: true vs. masked predicted dKP.
y_true = df_pred_matten["dKP_true"]
y_pred = df_pred_matten["dKP_matten_masked"]

mae = mean_absolute_error(y_true, y_pred)
# sklearn returns MAPE as a fraction, not as a percentage.
mape = mean_absolute_percentage_error(y_true, y_pred)
rmse = root_mean_squared_error(y_true, y_pred)
spearmanrho = spearmanr(y_true, y_pred)
for label, value in [
    ("MAE", mae),
    ("MAPE", mape),
    ("RMSE", rmse),
    ("Rho_sp", spearmanrho.statistic),
]:
    print(f"{label} = {value}")

# One marker per structure; hovering shows its id.
scatter_dKP = go.Scatter(
    x=y_true,
    y=y_pred,
    mode="markers",
    name="",
    showlegend=False,
    text=df_pred_matten.index.tolist(),
)
# y = x reference line, drawn past the visible range.
ideal = go.Scatter(
    x=[-1, 200],
    y=[-1, 200],
    mode="lines",
    line=dict(color="gray", dash="dot"),
    showlegend=False,
)

layout = go.Layout(
    xaxis=dict(
        title="<i>d</i><sub>KP</sub> (pm/V)", range=[-1, 170], showgrid=False
    ),
    yaxis=dict(
        title="<i>d&#770;</i><sub>KP</sub> (pm/V)", range=[-1, 170], showgrid=False
    ),
)

fig = go.Figure(data=[scatter_dKP, ideal], layout=layout)
fig.update_layout(
    autosize=False,
    font_size=20,
    width=600,
    height=600,
    template="simple_white",
)
fig.show()

MAE = 7.0487615936703865
MAPE = 5.3500630146287955
RMSE = 18.619103894929502
Rho_sp = 0.7435023041474653


# Errors on dRMS

In [15]:
# Regression metrics and parity plot: true vs. predicted dRMS (unmasked).
y_true = df_pred_matten["dRMS_true"]
y_pred = df_pred_matten["dRMS_matten"]

mae = mean_absolute_error(y_true, y_pred)
# sklearn returns MAPE as a fraction, not as a percentage.
mape = mean_absolute_percentage_error(y_true, y_pred)
rmse = root_mean_squared_error(y_true, y_pred)
spearmanrho = spearmanr(y_true, y_pred)
for label, value in [
    ("MAE", mae),
    ("MAPE", mape),
    ("RMSE", rmse),
    ("Rho_sp", spearmanrho.statistic),
]:
    print(f"{label} = {value}")

# One marker per structure; hovering shows its id.
scatter_dRMS = go.Scatter(
    x=y_true,
    y=y_pred,
    mode="markers",
    name="",
    showlegend=False,
    text=df_pred_matten.index.tolist(),
)
# y = x reference line, drawn past the visible range.
ideal = go.Scatter(
    x=[-1, 200],
    y=[-1, 200],
    mode="lines",
    line=dict(color="gray", dash="dot"),
    showlegend=False,
)

layout = go.Layout(
    xaxis=dict(
        title="<i>d</i><sub>RMS</sub> (pm/V)", range=[-1, 100], showgrid=False
    ),
    yaxis=dict(
        title="<i>d&#770;</i><sub>RMS</sub> (pm/V)", range=[-1, 100], showgrid=False
    ),
)

fig = go.Figure(data=[scatter_dRMS, ideal], layout=layout)
fig.update_layout(
    autosize=False,
    font_size=20,
    width=600,
    height=600,
    template="simple_white",
)
fig.show()

MAE = 3.4497516375080197
MAPE = 5.361874677480903
RMSE = 9.165602414343587
Rho_sp = 0.801216589861751
