In [1]:
from pathlib import Path

from scipy.stats import spearmanr
from sklearn.metrics import (
    mean_absolute_error,
    mean_absolute_percentage_error,
    root_mean_squared_error,
)

In [2]:
import os

import numpy as np
import pandas as pd
import plotly.graph_objs as go

In [3]:
import shg_ml_benchmarks.utils_shg as shg
from shg_ml_benchmarks.utils import load_full

In [4]:
# Reference dataset and the Matten predictions; reference rows are looked up
# below through the index labels of the prediction frame.
df_orig = load_full()
df_pred_matten = pd.read_json("df_pred_matten_holdout.json.gz")

# Derive the scalar quantities (dKP, dRMS) from every predicted tensor, once on
# the raw prediction and once on a "masked" prediction in which every component
# that is exactly zero in the reference tensor is forced to zero.
list_dKP_matten, list_dRMS_matten = [], []
list_dKP_matten_masked, list_dRMS_matten_masked = [], []
for entry_id, pred_stored in df_pred_matten["dijk_matten"].items():
    pred_tensor = shg.from_voigt(pred_stored)
    ref_tensor = shg.from_voigt(df_orig.loc[entry_id, "dijk_full_neum"])
    pred_masked = np.where(ref_tensor != 0, pred_tensor, 0)

    list_dKP_matten.append(shg.get_dKP(pred_tensor))
    list_dRMS_matten.append(shg.get_dRMS(pred_tensor))
    list_dKP_matten_masked.append(shg.get_dKP(pred_masked))
    list_dRMS_matten_masked.append(shg.get_dRMS(pred_masked))

# Attach the derived quantities as new columns (insertion order is preserved).
derived_columns = {
    "dKP_matten": list_dKP_matten,
    "dRMS_matten": list_dRMS_matten,
    "dKP_matten_masked": list_dKP_matten_masked,
    "dRMS_matten_masked": list_dRMS_matten_masked,
}
for column_name, column_values in derived_columns.items():
    df_pred_matten[column_name] = column_values

# Reference ("true") values for the rows that have a prediction.
df_orig_matched = df_orig.filter(df_pred_matten.index, axis=0)
df_pred_matten["dKP_true"] = df_orig_matched["dKP_full_neum"].tolist()
# NOTE(review): the reference tensors are handed to get_dRMS as stored, without
# the from_voigt call used in the loop above -- confirm get_dRMS accepts both.
df_pred_matten["dRMS_true"] = [
    shg.get_dRMS(ref_stored)
    for ref_stored in df_orig_matched["dijk_full_neum"].tolist()
]

print(df_pred_matten.shape)
display(df_pred_matten.head())

(125, 7)


Unnamed: 0,dijk_matten,dKP_matten,dRMS_matten,dKP_matten_masked,dRMS_matten_masked,dKP_true,dRMS_true
agm001234439,"[[[4.411775589, 0.0310529899, 0.1241508275], [...",2.25072,1.25295,2.25072,1.25295,0.827288,0.453301
agm001375463,"[[[0.0001351187, 6.60228e-05, 6.61535e-05], [6...",180.070511,100.438532,180.070511,100.438532,29.619974,16.521232
agm002017294,"[[[-75.7248458862, 4.6550000000000003e-07, -0....",114.496833,50.945691,114.496833,50.945691,16.934658,9.124151
agm002041315,"[[[12.4770307541, 13.4259977341, -10.479684829...",13.520093,7.481015,13.520093,7.481015,58.140328,28.108656
agm002041336,"[[[-1.6182515621, -1.3291000000000001e-06, -7....",6.710461,3.742071,6.710461,3.742071,42.717236,21.462845


In [5]:
# The notebook is expected to run from a directory named "predict_<type_set>";
# the suffix selects the matching "../scripts_<type_set>" training directory.
cwd_name = Path.cwd().name
name_parts = cwd_name.split("predict_")
if len(name_parts) < 2 or not name_parts[1]:
    # Fail with an explicit message instead of a bare IndexError / a path with
    # an empty set name when the notebook is started from another directory.
    raise ValueError(
        "Expected to run from a directory named 'predict_<type_set>', "
        f"but the current directory is {cwd_name!r}"
    )
type_set = name_parts[1]

# Per-epoch metrics written by the training run (logger version 0).
df_lr = pd.read_csv(f"../scripts_{type_set}/lightning_logs/version_0/metrics.csv")
print(df_lr.shape)
display(df_lr.head())

(3001, 13)


Unnamed: 0,cumulative time,epoch,epoch time,metric_test/MeanAbsoluteError/shg_tensor_full,metric_val/MeanAbsoluteError/shg_tensor_full,step,test/loss/shg_tensor_full,test/total_loss,train/loss/shg_tensor_full,train/total_loss,val/loss/shg_tensor_full,val/score,val/total_loss
0,62.134197,0,35.851181,,6.513093,78,,,,,914.463623,6.513093,914.463623
1,,0,,,,78,,,690.202209,690.202209,,,
2,98.240906,1,36.106709,,7.479188,157,,,,,1181.090942,7.479188,1181.090942
3,,1,,,,157,,,653.265991,653.265991,,,
4,134.354065,2,36.113152,,10.523038,236,,,,,4312.287109,10.523038,4312.287109


In [6]:
# The metrics log stores validation and training values of one epoch on
# separate rows. Select each kind of row by the column it actually populates
# rather than by its position within the epoch, so that extra rows (e.g. a
# final test row) or a different logging order cannot leak into either frame.
# The "epoch" column is kept and additionally used as the index.
df_lr_epochs_first = df_lr.dropna(subset=["val/loss/shg_tensor_full"]).set_index(
    "epoch", drop=False
)
display(df_lr_epochs_first.head())
df_lr_epochs_second = df_lr.dropna(subset=["train/loss/shg_tensor_full"]).set_index(
    "epoch", drop=False
)
display(df_lr_epochs_second.head())

Unnamed: 0_level_0,cumulative time,epoch,epoch time,metric_test/MeanAbsoluteError/shg_tensor_full,metric_val/MeanAbsoluteError/shg_tensor_full,step,test/loss/shg_tensor_full,test/total_loss,train/loss/shg_tensor_full,train/total_loss,val/loss/shg_tensor_full,val/score,val/total_loss
epoch,Unnamed: 1_level_1,Unnamed: 2_level_1,Unnamed: 3_level_1,Unnamed: 4_level_1,Unnamed: 5_level_1,Unnamed: 6_level_1,Unnamed: 7_level_1,Unnamed: 8_level_1,Unnamed: 9_level_1,Unnamed: 10_level_1,Unnamed: 11_level_1,Unnamed: 12_level_1,Unnamed: 13_level_1
0,62.134197,0,35.851181,,6.513093,78,,,,,914.463623,6.513093,914.463623
1,98.240906,1,36.106709,,7.479188,157,,,,,1181.090942,7.479188,1181.090942
2,134.354065,2,36.113152,,10.523038,236,,,,,4312.287109,10.523038,4312.287109
3,170.479858,3,36.125805,,8.73436,315,,,,,2768.166504,8.73436,2768.166504
4,206.67334,4,36.193474,,6.644928,394,,,,,1038.801514,6.644928,1038.801514


Unnamed: 0_level_0,cumulative time,epoch,epoch time,metric_test/MeanAbsoluteError/shg_tensor_full,metric_val/MeanAbsoluteError/shg_tensor_full,step,test/loss/shg_tensor_full,test/total_loss,train/loss/shg_tensor_full,train/total_loss,val/loss/shg_tensor_full,val/score,val/total_loss
epoch,Unnamed: 1_level_1,Unnamed: 2_level_1,Unnamed: 3_level_1,Unnamed: 4_level_1,Unnamed: 5_level_1,Unnamed: 6_level_1,Unnamed: 7_level_1,Unnamed: 8_level_1,Unnamed: 9_level_1,Unnamed: 10_level_1,Unnamed: 11_level_1,Unnamed: 12_level_1,Unnamed: 13_level_1
0,,0,,,,78,,,690.202209,690.202209,,,
1,,1,,,,157,,,653.265991,653.265991,,,
2,,2,,,,236,,,563.309204,563.309204,,,
3,,3,,,,315,,,481.513153,481.513153,,,
4,,4,,,,394,,,428.951813,428.951813,,,


# Learning curves

## Loss

In [7]:
# Training and validation loss as a function of the epoch.
lr_train = go.Scatter(
    x=df_lr_epochs_second.index,
    y=df_lr_epochs_second["train/loss/shg_tensor_full"],
    mode="lines+markers",
    name="Train loss",
    showlegend=True,
)
lr_val = go.Scatter(
    x=df_lr_epochs_first.index,
    y=df_lr_epochs_first["val/loss/shg_tensor_full"],
    mode="lines+markers",
    name="Validation loss",
    showlegend=True,
)

# Both axes are labelled so the figure can be read on its own; grids are off.
layout = go.Layout(
    xaxis=dict(title="Epoch", showgrid=False),
    yaxis=dict(title="Loss", showgrid=False),
)

fig = go.Figure(data=[lr_train, lr_val], layout=layout)
fig.update_layout(
    autosize=True,
    font_size=20,
    template="simple_white",
)

fig.show()

## MAE

In [8]:
# Validation mean absolute error as a function of the epoch.
lr_val = go.Scatter(
    x=df_lr_epochs_first.index,
    y=df_lr_epochs_first["metric_val/MeanAbsoluteError/shg_tensor_full"],
    mode="lines+markers",
    # The trace shows the MAE metric, not the loss, so the legend says so.
    name="Validation MAE",
    showlegend=True,
)

layout = go.Layout(
    xaxis=dict(title="Epoch", showgrid=False),
    yaxis=dict(title="MAE (pm/V)", showgrid=False),
)

fig = go.Figure(data=[lr_val], layout=layout)
fig.update_layout(
    autosize=True,
    font_size=20,
    template="simple_white",
)

fig.show()

# Errors on dijk

In [9]:
# Mean error norm between reference and predicted tensors. Each entry is
# np.linalg.norm(reference - prediction) with default arguments, i.e. the
# square root of the sum of squared component differences (no normalisation).
# Entries whose reference tensor has no non-zero component are left out.
list_fronorm = []
for entry_id, pred_stored in df_pred_matten["dijk_matten"].items():
    ref_tensor = shg.from_voigt(df_orig.loc[entry_id, "dijk_full_neum"])
    pred_tensor = shg.from_voigt(pred_stored)
    if np.count_nonzero(ref_tensor) != 0:
        list_fronorm.append(np.linalg.norm(ref_tensor - pred_tensor))
print(np.mean(list_fronorm))

50.35855147221536


In [10]:
# Mean error norm per non-zero reference component: for each entry,
# np.linalg.norm(reference - prediction) divided by the number of non-zero
# components of the reference tensor. Entries with an all-zero reference
# tensor are left out (the division would be by zero).
list_fronorm = []
for entry_id, pred_stored in df_pred_matten["dijk_matten"].items():
    ref_tensor = shg.from_voigt(df_orig.loc[entry_id, "dijk_full_neum"])
    pred_tensor = shg.from_voigt(pred_stored)
    n_nonzero = np.count_nonzero(ref_tensor)
    if n_nonzero == 0:
        continue
    error_norm = np.linalg.norm(ref_tensor - pred_tensor)
    list_fronorm.append(error_norm / n_nonzero)
print(np.mean(list_fronorm))

6.906199027537919


## Errors on dijk masked

In [11]:
# Masked variant of the mean error norm: before comparing, every predicted
# component whose reference counterpart is exactly zero is set to zero, so only
# components that are non-zero in the reference contribute to the error.
# Each entry is np.linalg.norm(reference - masked prediction); entries with an
# all-zero reference tensor are left out.
list_fronorm = []
for entry_id, pred_stored in df_pred_matten["dijk_matten"].items():
    ref_tensor = shg.from_voigt(df_orig.loc[entry_id, "dijk_full_neum"])
    if np.count_nonzero(ref_tensor) == 0:
        continue
    pred_masked = np.where(ref_tensor != 0, shg.from_voigt(pred_stored), 0)
    list_fronorm.append(np.linalg.norm(ref_tensor - pred_masked))
print(np.mean(list_fronorm))

47.35347869353275


In [12]:
# Masked error norm per non-zero reference component: predicted components
# whose reference counterpart is exactly zero are set to zero, then
# np.linalg.norm(reference - masked prediction) is divided by the number of
# non-zero reference components. Entries with an all-zero reference tensor are
# left out (the division would be by zero).
list_fronorm = []  # = sqrt(sum(sqrd))/nb_nonzero
for ir, r in df_pred_matten.iterrows():
    dijk_orig = shg.from_voigt(df_orig.loc[ir]["dijk_full_neum"])
    nb_nonzero = np.count_nonzero(dijk_orig)
    if nb_nonzero == 0:
        continue
    dijk_matten = shg.from_voigt(r["dijk_matten"])
    dijk_matten = np.where(dijk_orig != 0, dijk_matten, 0)
    list_fronorm.append(np.linalg.norm(dijk_orig - dijk_matten) / nb_nonzero)
print(np.mean(list_fronorm))

6.421481040687913


# Errors on dKP

In [13]:
# Regression metrics and parity plot for dKP derived from the predicted tensors.
y_true = df_pred_matten["dKP_true"]
y_pred = df_pred_matten["dKP_matten"]

mae = mean_absolute_error(y_true, y_pred)
# scikit-learn reports MAPE as a fraction (1.0 corresponds to 100 %).
mape = mean_absolute_percentage_error(y_true, y_pred)
rmse = root_mean_squared_error(y_true, y_pred)
spearmanrho = spearmanr(y_true, y_pred)
print(f"MAE = {mae}")
print(f"MAPE = {mape}")
print(f"RMSE = {rmse}")
print(f"Rho_sp = {spearmanrho.statistic}")

# Upper axis bound: at least the original 170 pm/V, but enlarged (5 % margin)
# when a true or predicted value exceeds it, so that no point is cut off.
axis_max = max(170.0, 1.05 * float(max(y_true.max(), y_pred.max())))
line_max = max(200.0, axis_max)

# Parity scatter; hovering a marker shows the index label of the entry.
scatter_dKP = go.Scatter(
    x=y_true,
    y=y_pred,
    mode="markers",
    name="",
    showlegend=False,
    text=df_pred_matten.index.tolist(),
)

# y = x reference line.
ideal = go.Scatter(
    x=[-1, line_max],
    y=[-1, line_max],
    mode="lines",
    line=dict(color="gray", dash="dot"),
    showlegend=False,
)

layout = go.Layout(
    xaxis=dict(
        title="<i>d</i><sub>KP</sub> (pm/V)",
        range=[-1, axis_max],
        showgrid=False,
    ),
    yaxis=dict(
        title="<i>d&#770;</i><sub>KP</sub> (pm/V)",
        range=[-1, axis_max],
        showgrid=False,
    ),
)

fig = go.Figure(data=[scatter_dKP, ideal], layout=layout)
fig.update_layout(
    autosize=False,
    font_size=20,
    width=600,
    height=600,
    template="simple_white",
)

fig.show()

MAE = 11.498741813088591
MAPE = 2.135298900037838
RMSE = 29.017833951511584
Rho_sp = 0.8107281105990781


## Errors on dKP masked

In [14]:
# Regression metrics and parity plot for dKP derived from the masked
# predictions (column "dKP_matten_masked").
y_true = df_pred_matten["dKP_true"]
y_pred = df_pred_matten["dKP_matten_masked"]

mae = mean_absolute_error(y_true, y_pred)
# scikit-learn reports MAPE as a fraction (1.0 corresponds to 100 %).
mape = mean_absolute_percentage_error(y_true, y_pred)
rmse = root_mean_squared_error(y_true, y_pred)
spearmanrho = spearmanr(y_true, y_pred)
print(f"MAE = {mae}")
print(f"MAPE = {mape}")
print(f"RMSE = {rmse}")
print(f"Rho_sp = {spearmanrho.statistic}")

# Upper axis bound: at least the original 170 pm/V, but enlarged (5 % margin)
# when a true or predicted value exceeds it, so that no point is cut off.
axis_max = max(170.0, 1.05 * float(max(y_true.max(), y_pred.max())))
line_max = max(200.0, axis_max)

# Parity scatter; hovering a marker shows the index label of the entry.
scatter_dKP = go.Scatter(
    x=y_true,
    y=y_pred,
    mode="markers",
    name="",
    showlegend=False,
    text=df_pred_matten.index.tolist(),
)

# y = x reference line.
ideal = go.Scatter(
    x=[-1, line_max],
    y=[-1, line_max],
    mode="lines",
    line=dict(color="gray", dash="dot"),
    showlegend=False,
)

layout = go.Layout(
    xaxis=dict(
        title="<i>d</i><sub>KP</sub> (pm/V)",
        range=[-1, axis_max],
        showgrid=False,
    ),
    yaxis=dict(
        title="<i>d&#770;</i><sub>KP</sub> (pm/V)",
        range=[-1, axis_max],
        showgrid=False,
    ),
)

fig = go.Figure(data=[scatter_dKP, ideal], layout=layout)
fig.update_layout(
    autosize=False,
    font_size=20,
    width=600,
    height=600,
    template="simple_white",
)

fig.show()

MAE = 13.19834674961361
MAPE = 2.1341470053197114
RMSE = 31.56970930317157
Rho_sp = 0.6489646697388631


# Errors on dRMS

In [15]:
# Regression metrics and parity plot for dRMS derived from the predicted tensors.
y_true = df_pred_matten["dRMS_true"]
y_pred = df_pred_matten["dRMS_matten"]

mae = mean_absolute_error(y_true, y_pred)
# scikit-learn reports MAPE as a fraction (1.0 corresponds to 100 %).
mape = mean_absolute_percentage_error(y_true, y_pred)
rmse = root_mean_squared_error(y_true, y_pred)
spearmanrho = spearmanr(y_true, y_pred)
print(f"MAE = {mae}")
print(f"MAPE = {mape}")
print(f"RMSE = {rmse}")
print(f"Rho_sp = {spearmanrho.statistic}")

# Upper axis bound: at least the original 100 pm/V, but enlarged (5 % margin)
# when a true or predicted value exceeds it, so that no point is cut off.
axis_max = max(100.0, 1.05 * float(max(y_true.max(), y_pred.max())))
line_max = max(200.0, axis_max)

# Parity scatter; hovering a marker shows the index label of the entry.
scatter_dRMS = go.Scatter(
    x=y_true,
    y=y_pred,
    mode="markers",
    name="",
    showlegend=False,
    text=df_pred_matten.index.tolist(),
)

# y = x reference line.
ideal = go.Scatter(
    x=[-1, line_max],
    y=[-1, line_max],
    mode="lines",
    line=dict(color="gray", dash="dot"),
    showlegend=False,
)

layout = go.Layout(
    xaxis=dict(
        title="<i>d</i><sub>RMS</sub> (pm/V)",
        range=[-1, axis_max],
        showgrid=False,
    ),
    yaxis=dict(
        title="<i>d&#770;</i><sub>RMS</sub> (pm/V)",
        range=[-1, axis_max],
        showgrid=False,
    ),
)

fig = go.Figure(data=[scatter_dRMS, ideal], layout=layout)
fig.update_layout(
    autosize=False,
    font_size=20,
    width=600,
    height=600,
    template="simple_white",
)

fig.show()

MAE = 5.99844856749669
MAPE = 2.1371670649825503
RMSE = 15.451509651682379
Rho_sp = 0.8070046082949308
