In [1]:
from pathlib import Path

from scipy.stats import spearmanr
from sklearn.metrics import (
    mean_absolute_error,
    mean_absolute_percentage_error,
    root_mean_squared_error,
)

In [2]:
import os

import numpy as np
import pandas as pd
import plotly.graph_objs as go

In [3]:
import shg_ml_benchmarks.utils_shg as shg
from shg_ml_benchmarks.utils import load_full

In [4]:
# Reference dataset and the Matten tensor predictions read from the JSON dump.
df_orig = load_full()
df_pred_matten = pd.read_json("df_pred_matten_holdout.json.gz")

# Reference tensors, looked up by sample id inside the loop below.
dijk_reference = df_orig["dijk_full_neum"]

# dKP / dRMS derived from every predicted tensor, once as predicted and once
# after zeroing the components that are zero in the reference tensor.
derived_columns = {
    "dKP_matten": [],
    "dRMS_matten": [],
    "dKP_matten_masked": [],
    "dRMS_matten_masked": [],
}
for sample_id, dijk_pred_raw in df_pred_matten["dijk_matten"].items():
    tensor_pred = shg.from_voigt(dijk_pred_raw)
    derived_columns["dKP_matten"].append(shg.get_dKP(tensor_pred))
    derived_columns["dRMS_matten"].append(shg.get_dRMS(tensor_pred))

    # Keep a predicted component only where the reference component is non-zero.
    tensor_ref = shg.from_voigt(dijk_reference.loc[sample_id])
    tensor_pred_masked = np.where(tensor_ref != 0, tensor_pred, 0)
    derived_columns["dKP_matten_masked"].append(shg.get_dKP(tensor_pred_masked))
    derived_columns["dRMS_matten_masked"].append(shg.get_dRMS(tensor_pred_masked))

# The dict keeps insertion order, so the columns are appended in that order.
for column_name, column_values in derived_columns.items():
    df_pred_matten[column_name] = column_values

# Reference values for the ids present in the prediction frame.
reference_rows = df_orig.filter(df_pred_matten.index, axis=0)
df_pred_matten["dKP_true"] = reference_rows["dKP_full_neum"].tolist()
# NOTE(review): unlike in the loop above, the stored tensor is passed to
# get_dRMS without going through shg.from_voigt first -- confirm that both
# paths give the same value.
df_pred_matten["dRMS_true"] = list(
    map(shg.get_dRMS, reference_rows["dijk_full_neum"].tolist())
)

print(df_pred_matten.shape)
display(df_pred_matten.head())

(250, 7)


Unnamed: 0,dijk_matten,dKP_matten,dRMS_matten,dKP_matten_masked,dRMS_matten_masked,dKP_true,dRMS_true
agm001234439,"[[[2.3539202213, -0.0060851327, 0.125241816], ...",2.405812,1.02376,2.405812,1.02376,0.827288,0.453301
agm001347560,"[[[-3.55328e-05, -7.4076e-06, -1.04688e-05], [...",202.858416,113.149018,202.858416,113.149018,119.177507,66.474038
agm002041688,"[[[15.9056596756, 4.3019003868, -3.7186942101]...",14.68591,6.796376,14.68591,6.796376,5.595534,2.876049
agm002054632,"[[[1.4559764862, 1.8480387926000001, -1.225272...",23.336928,10.84069,23.336928,10.84069,45.303136,21.07202
agm002072314,"[[[9.235e-07, -1.6690000000000002e-07, -2.0766...",1.473536,0.718447,1.473536,0.718447,3.674658,1.610681


In [5]:
# The part of the working-directory name that follows "predict_" identifies
# which scripts folder holds the training log to read.
cwd_name = Path.cwd().name
type_set = cwd_name.split("predict_")[1]

df_lr = pd.read_csv(f"../scripts_{type_set}/lightning_logs/version_0/metrics.csv")
print(df_lr.shape)
display(df_lr.head())

(3001, 13)


Unnamed: 0,cumulative time,epoch,epoch time,metric_test/MeanAbsoluteError/shg_tensor_full,metric_val/MeanAbsoluteError/shg_tensor_full,step,test/loss/shg_tensor_full,test/total_loss,train/loss/shg_tensor_full,train/total_loss,val/loss/shg_tensor_full,val/score,val/total_loss
0,60.91917,0,34.043652,,5.435301,74,,,,,600.164001,5.435301,600.164001
1,,0,,,,74,,,712.227295,712.227295,,,
2,95.651917,1,34.73275,,5.401658,149,,,,,582.088501,5.401658,582.088501
3,,1,,,,149,,,673.345154,673.345154,,,
4,130.24411,2,34.592194,,6.523003,224,,,,,1121.05603,6.523003,1121.05603


In [6]:
# metrics.csv logs each epoch on separate rows: one row carries the validation
# metrics, another the training loss. Select the rows by the metric they
# actually contain rather than by their position within the epoch
# (`duplicated(keep=...)`): the positional split silently mixes rows as soon as
# an epoch has an extra row or is missing one, and the log has an odd number
# of rows, so the rows are not strictly paired.
#
# Both frames are indexed by epoch while keeping the "epoch" column.
df_lr_epochs_first = df_lr.dropna(subset=["val/loss/shg_tensor_full"]).set_index(
    "epoch", drop=False
)
display(df_lr_epochs_first.head())
df_lr_epochs_second = df_lr.dropna(subset=["train/loss/shg_tensor_full"]).set_index(
    "epoch", drop=False
)
display(df_lr_epochs_second.head())

Unnamed: 0_level_0,cumulative time,epoch,epoch time,metric_test/MeanAbsoluteError/shg_tensor_full,metric_val/MeanAbsoluteError/shg_tensor_full,step,test/loss/shg_tensor_full,test/total_loss,train/loss/shg_tensor_full,train/total_loss,val/loss/shg_tensor_full,val/score,val/total_loss
epoch,Unnamed: 1_level_1,Unnamed: 2_level_1,Unnamed: 3_level_1,Unnamed: 4_level_1,Unnamed: 5_level_1,Unnamed: 6_level_1,Unnamed: 7_level_1,Unnamed: 8_level_1,Unnamed: 9_level_1,Unnamed: 10_level_1,Unnamed: 11_level_1,Unnamed: 12_level_1,Unnamed: 13_level_1
0,60.91917,0,34.043652,,5.435301,74,,,,,600.164001,5.435301,600.164001
1,95.651917,1,34.73275,,5.401658,149,,,,,582.088501,5.401658,582.088501
2,130.24411,2,34.592194,,6.523003,224,,,,,1121.05603,6.523003,1121.05603
3,164.820572,3,34.576462,,5.705935,299,,,,,681.567505,5.705935,681.567505
4,199.820312,4,34.999733,,5.48914,374,,,,,644.286133,5.48914,644.286133


Unnamed: 0_level_0,cumulative time,epoch,epoch time,metric_test/MeanAbsoluteError/shg_tensor_full,metric_val/MeanAbsoluteError/shg_tensor_full,step,test/loss/shg_tensor_full,test/total_loss,train/loss/shg_tensor_full,train/total_loss,val/loss/shg_tensor_full,val/score,val/total_loss
epoch,Unnamed: 1_level_1,Unnamed: 2_level_1,Unnamed: 3_level_1,Unnamed: 4_level_1,Unnamed: 5_level_1,Unnamed: 6_level_1,Unnamed: 7_level_1,Unnamed: 8_level_1,Unnamed: 9_level_1,Unnamed: 10_level_1,Unnamed: 11_level_1,Unnamed: 12_level_1,Unnamed: 13_level_1
0,,0,,,,74,,,712.227295,712.227295,,,
1,,1,,,,149,,,673.345154,673.345154,,,
2,,2,,,,224,,,580.091187,580.091187,,,
3,,3,,,,299,,,475.042725,475.042725,,,
4,,4,,,,374,,,419.99176,419.99176,,,


# Learning curves

## Loss

In [7]:
# Training and validation loss as a function of the epoch.
lr_val = go.Scatter(
    x=df_lr_epochs_first.index,
    y=df_lr_epochs_first["val/loss/shg_tensor_full"],
    mode="lines+markers",
    name="Validation loss",
    showlegend=True,
)
lr_train = go.Scatter(
    x=df_lr_epochs_second.index,
    y=df_lr_epochs_second["train/loss/shg_tensor_full"],
    mode="lines+markers",
    name="Train loss",
    showlegend=True,
)

# Both axes are labelled so the figure can be read on its own; grids are off.
layout = go.Layout(
    xaxis=dict(title="Epoch", showgrid=False),
    yaxis=dict(title="Loss", showgrid=False),
)

fig = go.Figure(data=[lr_train, lr_val], layout=layout)
fig.update_layout(
    autosize=True,
    font_size=20,
    template="simple_white",
)

fig.show()

## MAE

In [8]:
# Validation mean absolute error as a function of the epoch.
lr_val = go.Scatter(
    x=df_lr_epochs_first.index,
    y=df_lr_epochs_first["metric_val/MeanAbsoluteError/shg_tensor_full"],
    mode="lines+markers",
    # The trace shows the MAE metric, so the legend entry must not say "loss".
    name="Validation MAE",
    showlegend=True,
)

layout = go.Layout(
    xaxis=dict(title="Epoch", showgrid=False),
    yaxis=dict(title="MAE (pm/V)", showgrid=False),
)

fig = go.Figure(data=[lr_val], layout=layout)
fig.update_layout(
    autosize=True,
    font_size=20,
    template="simple_white",
)

fig.show()

# Errors on dijk

In [9]:
# Frobenius norm of the error tensor, sqrt(sum(squared differences)), for every
# sample whose reference tensor has at least one non-zero component; the mean
# over those samples is printed.
list_fronorm = []
for sample_id, sample in df_pred_matten.iterrows():
    tensor_ref = shg.from_voigt(df_orig.loc[sample_id]["dijk_full_neum"])
    if not np.any(tensor_ref):
        # All-zero reference tensor: the sample is left out of the average.
        continue
    tensor_pred = shg.from_voigt(sample["dijk_matten"])
    list_fronorm.append(np.linalg.norm(tensor_ref - tensor_pred))
print(np.mean(list_fronorm))

32.481885664329965


In [10]:
# Frobenius norm of the error tensor divided by the number of non-zero
# reference components, i.e. sqrt(sum(squared differences)) / nb_nonzero.
list_fronorm = []
for sample_id, sample in df_pred_matten.iterrows():
    tensor_ref = shg.from_voigt(df_orig.loc[sample_id]["dijk_full_neum"])
    nb_nonzero = np.count_nonzero(tensor_ref)
    if nb_nonzero == 0:
        # Nothing to normalise by: skip samples with an all-zero reference.
        continue
    error = tensor_ref - shg.from_voigt(sample["dijk_matten"])
    list_fronorm.append(np.linalg.norm(error) / nb_nonzero)
print(np.mean(list_fronorm))

3.831095507218418


## Errors on dijk masked

In [11]:
# Masked variant: predicted components are zeroed wherever the reference
# tensor is zero before taking the Frobenius norm, sqrt(sum(squared diffs)).
list_fronorm = []
for sample_id, dijk_pred_raw in df_pred_matten["dijk_matten"].items():
    tensor_ref = shg.from_voigt(df_orig.loc[sample_id]["dijk_full_neum"])
    if not np.any(tensor_ref):
        # All-zero reference tensor: the sample is left out of the average.
        continue
    tensor_pred = np.where(tensor_ref != 0, shg.from_voigt(dijk_pred_raw), 0)
    list_fronorm.append(np.linalg.norm(tensor_ref - tensor_pred))
print(np.mean(list_fronorm))

30.95705399523474


In [12]:
# Masked variant of the normalised error: predicted components are zeroed
# wherever the reference tensor is zero, then the Frobenius norm of the
# difference is divided by the number of non-zero reference components.
list_fronorm = []  # = sqrt(sum(sqrd))/nb_nonzero
for ir, r in df_pred_matten.iterrows():
    dijk_orig = shg.from_voigt(df_orig.loc[ir]["dijk_full_neum"])
    nb_nonzero = np.count_nonzero(dijk_orig)
    if nb_nonzero == 0:
        # Nothing to normalise by: skip samples with an all-zero reference.
        continue
    dijk_matten = np.where(dijk_orig != 0, shg.from_voigt(r["dijk_matten"]), 0)
    list_fronorm.append(np.linalg.norm(dijk_orig - dijk_matten) / nb_nonzero)
print(np.mean(list_fronorm))

3.5923375570450076


# Errors on dKP

In [13]:
# Error metrics between the reference dKP and the dKP derived from the
# predicted tensors.
dkp_true = df_pred_matten["dKP_true"]
dkp_pred = df_pred_matten["dKP_matten"]

mae = mean_absolute_error(dkp_true, dkp_pred)
# NOTE: scikit-learn reports MAPE as a fraction, not as a percentage.
mape = mean_absolute_percentage_error(dkp_true, dkp_pred)
rmse = root_mean_squared_error(dkp_true, dkp_pred)
spearmanrho = spearmanr(dkp_true, dkp_pred)
print(f"MAE = {mae}")
print(f"MAPE = {mape}")
print(f"RMSE = {rmse}")
print(f"Rho_sp = {spearmanrho.statistic}")

# Parity plot: one marker per sample, with the sample id as hover text.
scatter_dKP = go.Scatter(
    x=dkp_true,
    y=dkp_pred,
    mode="markers",
    name="",
    showlegend=False,
    text=list(df_pred_matten.index.values),
)

# Dotted y = x guide line.
diagonal = [-1, 200]
ideal = go.Scatter(
    x=diagonal,
    y=diagonal,
    mode="lines",
    line=dict(color="gray", dash="dot"),
    showlegend=False,
)

# Same range on both axes, no grid.
axis_range = [-1, 170]
layout = go.Layout(
    xaxis=dict(
        title="<i>d</i><sub>KP</sub> (pm/V)",
        range=axis_range,
        showgrid=False,
    ),
    yaxis=dict(
        title="<i>d&#770;</i><sub>KP</sub> (pm/V)",
        range=axis_range,
        showgrid=False,
    ),
)

fig = go.Figure(data=[scatter_dKP, ideal], layout=layout)
fig.update_layout(
    template="simple_white",
    font_size=20,
    autosize=False,
    width=600,
    height=600,
)

fig.show()

MAE = 8.212044033901746
MAPE = 6.180294580849674
RMSE = 20.6763371639879
Rho_sp = 0.7988975183602938


## Errors on dKP masked

In [14]:
# Error metrics between the reference dKP and the dKP derived from the masked
# predicted tensors.
dkp_true = df_pred_matten["dKP_true"]
dkp_pred_masked = df_pred_matten["dKP_matten_masked"]

mae = mean_absolute_error(dkp_true, dkp_pred_masked)
# NOTE: scikit-learn reports MAPE as a fraction, not as a percentage.
mape = mean_absolute_percentage_error(dkp_true, dkp_pred_masked)
rmse = root_mean_squared_error(dkp_true, dkp_pred_masked)
spearmanrho = spearmanr(dkp_true, dkp_pred_masked)
print(f"MAE = {mae}")
print(f"MAPE = {mape}")
print(f"RMSE = {rmse}")
print(f"Rho_sp = {spearmanrho.statistic}")

# Parity plot: one marker per sample, with the sample id as hover text.
scatter_dKP = go.Scatter(
    x=dkp_true,
    y=dkp_pred_masked,
    mode="markers",
    name="",
    showlegend=False,
    text=list(df_pred_matten.index.values),
)

# Dotted y = x guide line.
diagonal = [-1, 200]
ideal = go.Scatter(
    x=diagonal,
    y=diagonal,
    mode="lines",
    line=dict(color="gray", dash="dot"),
    showlegend=False,
)

# Same range on both axes, no grid.
axis_range = [-1, 170]
layout = go.Layout(
    xaxis=dict(
        title="<i>d</i><sub>KP</sub> (pm/V)",
        range=axis_range,
        showgrid=False,
    ),
    yaxis=dict(
        title="<i>d&#770;</i><sub>KP</sub> (pm/V)",
        range=axis_range,
        showgrid=False,
    ),
)

fig = go.Figure(data=[scatter_dKP, ideal], layout=layout)
fig.update_layout(
    template="simple_white",
    font_size=20,
    autosize=False,
    width=600,
    height=600,
)

fig.show()

MAE = 8.644536396996326
MAPE = 5.907699199752137
RMSE = 21.71973355815341
Rho_sp = 0.661347285556569


# Errors on dRMS

In [15]:
# Error metrics between the reference dRMS and the dRMS derived from the
# predicted tensors.
drms_true = df_pred_matten["dRMS_true"]
drms_pred = df_pred_matten["dRMS_matten"]

mae = mean_absolute_error(drms_true, drms_pred)
# NOTE: scikit-learn reports MAPE as a fraction, not as a percentage.
mape = mean_absolute_percentage_error(drms_true, drms_pred)
rmse = root_mean_squared_error(drms_true, drms_pred)
spearmanrho = spearmanr(drms_true, drms_pred)
print(f"MAE = {mae}")
print(f"MAPE = {mape}")
print(f"RMSE = {rmse}")
print(f"Rho_sp = {spearmanrho.statistic}")

# Parity plot: one marker per sample, with the sample id as hover text.
scatter_dRMS = go.Scatter(
    x=drms_true,
    y=drms_pred,
    mode="markers",
    name="",
    showlegend=False,
    text=list(df_pred_matten.index.values),
)

# Dotted y = x guide line.
diagonal = [-1, 200]
ideal = go.Scatter(
    x=diagonal,
    y=diagonal,
    mode="lines",
    line=dict(color="gray", dash="dot"),
    showlegend=False,
)

# Same range on both axes, no grid.
axis_range = [-1, 100]
layout = go.Layout(
    xaxis=dict(
        title="<i>d</i><sub>RMS</sub> (pm/V)",
        range=axis_range,
        showgrid=False,
    ),
    yaxis=dict(
        title="<i>d&#770;</i><sub>RMS</sub> (pm/V)",
        range=axis_range,
        showgrid=False,
    ),
)

fig = go.Figure(data=[scatter_dRMS, ideal], layout=layout)
fig.update_layout(
    template="simple_white",
    font_size=20,
    autosize=False,
    width=600,
    height=600,
)

fig.show()

MAE = 4.267046665197005
MAPE = 6.213331935219069
RMSE = 10.703665353730786
Rho_sp = 0.7997323477175634
