In [5]:
from pathlib import Path

from scipy.stats import spearmanr
from sklearn.metrics import (
    mean_absolute_error,
    mean_absolute_percentage_error,
    root_mean_squared_error,
)

In [6]:
import os

import numpy as np
import pandas as pd
import plotly.graph_objs as go

In [7]:
import shg_ml_benchmarks.utils_shg as shg
from shg_ml_benchmarks.utils import load_full

In [8]:
# Load the reference dataset and the Matten hold-out predictions.
df_orig = load_full()
df_pred_matten = pd.read_json("df_pred_matten_holdout.json.gz")

# Reference rows, label-aligned with the predictions (same ids, same order).
# `.loc` raises a KeyError if a predicted id is absent from df_orig, whereas
# `DataFrame.filter(items, axis=0)` would silently drop it and only fail later
# with a length mismatch when assigning the new columns.
df_true = df_orig.loc[df_pred_matten.index]

# Computing dKP and dRMS from the Matten tensor predictions, once on the raw
# prediction and once on a "masked" prediction in which every component that
# is exactly zero in the reference tensor is forced to zero as well.
list_dKP_matten = []
list_dRMS_matten = []
list_dKP_matten_masked = []
list_dRMS_matten_masked = []
for dijk_pred_raw, dijk_true_raw in zip(
    df_pred_matten["dijk_matten"], df_true["dijk_full_neum"]
):
    dijk_matten = shg.from_voigt(dijk_pred_raw)
    list_dKP_matten.append(shg.get_dKP(dijk_matten))
    list_dRMS_matten.append(shg.get_dRMS(dijk_matten))

    dijk_orig = shg.from_voigt(dijk_true_raw)
    # Keep the unmasked prediction intact; use a dedicated name for the mask.
    dijk_matten_masked = np.where(dijk_orig != 0, dijk_matten, 0)
    list_dKP_matten_masked.append(shg.get_dKP(dijk_matten_masked))
    list_dRMS_matten_masked.append(shg.get_dRMS(dijk_matten_masked))

df_pred_matten["dKP_matten"] = list_dKP_matten
df_pred_matten["dRMS_matten"] = list_dRMS_matten
df_pred_matten["dKP_matten_masked"] = list_dKP_matten_masked
df_pred_matten["dRMS_matten_masked"] = list_dRMS_matten_masked

# Adding the true dKP and dRMS to the df
df_pred_matten["dKP_true"] = df_true["dKP_full_neum"].tolist()
# NOTE(review): the reference tensor is passed to get_dRMS without going
# through shg.from_voigt, unlike the predictions above -- confirm that
# get_dRMS accepts the stored "dijk_full_neum" representation directly.
df_pred_matten["dRMS_true"] = [shg.get_dRMS(d) for d in df_true["dijk_full_neum"]]

print(df_pred_matten.shape)
display(df_pred_matten.head())

(250, 7)


Unnamed: 0,dijk_matten,dKP_matten,dRMS_matten,dKP_matten_masked,dRMS_matten_masked,dKP_true,dRMS_true
mp-21892,"[[[-0.8854632378, -2.9232000000000002e-06, -3....",3.373881,1.867279,1e-06,7.705431e-07,1.670288,0.832098
mp-8070,"[[[2.4966099262, 3.4148000000000003e-06, 0.405...",5.91008,2.70685,5.91008,2.70685,0.615357,0.30556
mp-558491,"[[[-0.1729813814, 0.1207283363, 5.18e-08], [0....",0.145567,0.081193,0.145567,0.08119334,0.916306,0.511091
mp-632684,"[[[-1.2744e-06, -1.124e-07, 0.7789599895], [-1...",1.049645,0.438047,1.049645,0.4380471,0.633297,0.261316
mp-23329,"[[[1.2354900000000001e-05, -0.6986798644000001...",9.077246,3.862126,2.971185,1.296609,8.216653,3.472409


In [9]:
# The working directory name is split on "predict_" and the part after it is
# used as the name of the data split (type_set).
cwd_name = Path.cwd().name
type_set = cwd_name.split("predict_")[1]

# Training log written by the Lightning CSV logger for that split.
metrics_csv = f"../scripts_{type_set}/lightning_logs/version_0/metrics.csv"
df_lr = pd.read_csv(metrics_csv)
print(df_lr.shape)
display(df_lr.head())

(1001, 13)


Unnamed: 0,cumulative time,epoch,epoch time,metric_test/MeanAbsoluteError/shg_tensor_full,metric_val/MeanAbsoluteError/shg_tensor_full,step,test/loss/shg_tensor_full,test/total_loss,train/loss/shg_tensor_full,train/total_loss,val/loss/shg_tensor_full,val/score,val/total_loss
0,42.650837,0,11.581379,,5.360172,66,,,,,793.076477,5.360172,793.076477
1,,0,,,,66,,,685.362244,685.362244,,,
2,53.455158,1,10.80432,,5.114588,133,,,,,709.228088,5.114588,709.228088
3,,1,,,,133,,,596.966492,596.966492,,,
4,64.224731,2,10.769577,,5.338074,200,,,,,745.423218,5.338074,745.423218


In [10]:
# The metrics log stores, for each epoch, one row with the validation metrics
# and one row with the training loss (the other columns are NaN on that row).
# Select each kind of row by the column it actually populates rather than by
# its position among duplicated epochs: this stays correct if an epoch has
# more or fewer than two rows, or if the row order changes.
# `drop=False` keeps the "epoch" column in addition to using it as the index.

# Validation rows, indexed by epoch.
df_lr_epochs_first = df_lr.dropna(subset=["val/loss/shg_tensor_full"]).set_index(
    "epoch", drop=False
)
display(df_lr_epochs_first.head())

# Training rows, indexed by epoch.
df_lr_epochs_second = df_lr.dropna(subset=["train/loss/shg_tensor_full"]).set_index(
    "epoch", drop=False
)
display(df_lr_epochs_second.head())

Unnamed: 0_level_0,cumulative time,epoch,epoch time,metric_test/MeanAbsoluteError/shg_tensor_full,metric_val/MeanAbsoluteError/shg_tensor_full,step,test/loss/shg_tensor_full,test/total_loss,train/loss/shg_tensor_full,train/total_loss,val/loss/shg_tensor_full,val/score,val/total_loss
epoch,Unnamed: 1_level_1,Unnamed: 2_level_1,Unnamed: 3_level_1,Unnamed: 4_level_1,Unnamed: 5_level_1,Unnamed: 6_level_1,Unnamed: 7_level_1,Unnamed: 8_level_1,Unnamed: 9_level_1,Unnamed: 10_level_1,Unnamed: 11_level_1,Unnamed: 12_level_1,Unnamed: 13_level_1
0,42.650837,0,11.581379,,5.360172,66,,,,,793.076477,5.360172,793.076477
1,53.455158,1,10.80432,,5.114588,133,,,,,709.228088,5.114588,709.228088
2,64.224731,2,10.769577,,5.338074,200,,,,,745.423218,5.338074,745.423218
3,75.429741,3,11.205003,,5.594665,267,,,,,842.742615,5.594665,842.742615
4,86.574913,4,11.145176,,5.942023,334,,,,,1181.420654,5.942023,1181.420654


Unnamed: 0_level_0,cumulative time,epoch,epoch time,metric_test/MeanAbsoluteError/shg_tensor_full,metric_val/MeanAbsoluteError/shg_tensor_full,step,test/loss/shg_tensor_full,test/total_loss,train/loss/shg_tensor_full,train/total_loss,val/loss/shg_tensor_full,val/score,val/total_loss
epoch,Unnamed: 1_level_1,Unnamed: 2_level_1,Unnamed: 3_level_1,Unnamed: 4_level_1,Unnamed: 5_level_1,Unnamed: 6_level_1,Unnamed: 7_level_1,Unnamed: 8_level_1,Unnamed: 9_level_1,Unnamed: 10_level_1,Unnamed: 11_level_1,Unnamed: 12_level_1,Unnamed: 13_level_1
0,,0,,,,66,,,685.362244,685.362244,,,
1,,1,,,,133,,,596.966492,596.966492,,,
2,,2,,,,200,,,473.835297,473.835297,,,
3,,3,,,,267,,,365.096344,365.096344,,,
4,,4,,,,334,,,275.406647,275.406647,,,


# Learning curves

## Loss

In [11]:
# Training and validation loss as a function of the epoch.
lr_val = go.Scatter(
    x=df_lr_epochs_first.index,
    y=df_lr_epochs_first["val/loss/shg_tensor_full"],
    mode="lines+markers",
    name="Validation loss",
    showlegend=True,
)
lr_train = go.Scatter(
    x=df_lr_epochs_second.index,
    y=df_lr_epochs_second["train/loss/shg_tensor_full"],
    mode="lines+markers",
    name="Train loss",
    showlegend=True,
)

# Layout: both axes are labelled so the figure can be read on its own.
layout = go.Layout(
    xaxis=dict(title="Epoch"),
    yaxis=dict(title="Loss"),
)

# Create figure
fig = go.Figure(data=[lr_train, lr_val], layout=layout)

# Styling: white template, larger font, no grid lines.
fig.update_layout(
    autosize=True,
    font_size=20,
    template="simple_white",
    xaxis=dict(showgrid=False),
    yaxis=dict(showgrid=False),
)

# Show figure
fig.show()

## MAE

In [12]:
# Validation mean absolute error as a function of the epoch.
lr_val = go.Scatter(
    x=df_lr_epochs_first.index,
    y=df_lr_epochs_first["metric_val/MeanAbsoluteError/shg_tensor_full"],
    mode="lines+markers",
    # The plotted column is the validation MAE, not the loss.
    name="Validation MAE",
    showlegend=True,
)

# Layout
layout = go.Layout(
    xaxis=dict(title="Epoch"),
    yaxis=dict(title="MAE (pm/V)"),
)

# Create figure
fig = go.Figure(data=[lr_val], layout=layout)

# Styling: white template, larger font, no grid lines.
fig.update_layout(
    autosize=True,
    font_size=20,
    template="simple_white",
    xaxis=dict(showgrid=False),
    yaxis=dict(showgrid=False),
)

# Show figure
fig.show()

# Errors on dijk

In [13]:
# Mean norm of the error tensor (reference - prediction) over the hold-out
# entries; entries whose reference tensor has no non-zero component are skipped.
# np.linalg.norm with default arguments is the 2-norm of the flattened array,
# i.e. sqrt(sum(sqrd)).
tensor_pairs = (
    (
        shg.from_voigt(df_orig.loc[mpid]["dijk_full_neum"]),
        shg.from_voigt(pred_row["dijk_matten"]),
    )
    for mpid, pred_row in df_pred_matten.iterrows()
)
list_fronorm = [
    np.linalg.norm(ref - pred)
    for ref, pred in tensor_pairs
    if np.count_nonzero(ref) != 0
]
print(np.mean(list_fronorm))

41.14298087854922


In [14]:
# Same error norm as sqrt(sum(sqrd)), divided by the number of non-zero
# components of the reference tensor; entries with an all-zero reference
# tensor are skipped.
list_fronorm = []
for mpid, pred_voigt in df_pred_matten["dijk_matten"].items():
    ref = shg.from_voigt(df_orig.loc[mpid]["dijk_full_neum"])
    pred = shg.from_voigt(pred_voigt)
    if np.count_nonzero(ref) != 0:
        error_norm = np.linalg.norm(ref - pred)
        list_fronorm.append(error_norm / np.count_nonzero(ref))
print(np.mean(list_fronorm))

5.21137425144025


## Errors on dijk masked

In [15]:
# Masked variant: prediction components located where the reference tensor is
# exactly zero are set to zero before the error norm sqrt(sum(sqrd)) is taken.
# Entries with an all-zero reference tensor are skipped.
list_fronorm = []
for mpid, pred_voigt in df_pred_matten["dijk_matten"].items():
    ref = shg.from_voigt(df_orig.loc[mpid]["dijk_full_neum"])
    pred_masked = np.where(ref != 0, shg.from_voigt(pred_voigt), 0)
    if np.count_nonzero(ref) != 0:
        list_fronorm.append(np.linalg.norm(ref - pred_masked))
print(np.mean(list_fronorm))

39.115518536342854


In [16]:
# masked
# Error norm of the masked prediction divided by the number of non-zero
# components of the reference tensor. Prediction components located where the
# reference tensor is exactly zero are set to zero first; entries with an
# all-zero reference tensor are skipped.
list_fronorm = []  # = sqrt(sum(sqrd))/nb_nonzero
for ir, r in df_pred_matten.iterrows():
    dijk_orig = shg.from_voigt(df_orig.loc[ir, "dijk_full_neum"])
    nb_nonzero = np.count_nonzero(dijk_orig)
    if nb_nonzero == 0:
        continue
    dijk_matten = np.where(dijk_orig != 0, shg.from_voigt(r["dijk_matten"]), 0)
    # Compute the difference once and count the non-zero components once.
    list_fronorm.append(np.linalg.norm(dijk_orig - dijk_matten) / nb_nonzero)
print(np.mean(list_fronorm))

4.884829758387934


# Errors on dKP

In [17]:
# Regression metrics of the predicted dKP against the reference dKP.
dKP_reference = df_pred_matten["dKP_true"]
dKP_predicted = df_pred_matten["dKP_matten"]

mae = mean_absolute_error(dKP_reference, dKP_predicted)
mape = mean_absolute_percentage_error(dKP_reference, dKP_predicted)
rmse = root_mean_squared_error(dKP_reference, dKP_predicted)
spearmanrho = spearmanr(dKP_reference, dKP_predicted)
print(f"MAE = {mae}")
print(f"MAPE = {mape}")
print(f"RMSE = {rmse}")
print(f"Rho_sp = {spearmanrho.statistic}")

# Parity plot: one marker per hold-out entry, labelled with its index value.
scatter_dKP = go.Scatter(
    x=dKP_reference,
    y=dKP_predicted,
    mode="markers",
    name="",
    showlegend=False,
    text=df_pred_matten.index.tolist(),
)

# Dotted y = x line marking a perfect prediction.
diagonal = [-1, 200]
ideal = go.Scatter(
    x=diagonal,
    y=diagonal,
    mode="lines",
    line={"color": "gray", "dash": "dot"},
    showlegend=False,
)

# Both axes share the same range so that the diagonal is at 45 degrees.
layout = go.Layout(
    xaxis={"title": "<i>d</i><sub>KP</sub> (pm/V)", "range": [-1, 170]},
    yaxis={"title": "<i>d&#770;</i><sub>KP</sub> (pm/V)", "range": [-1, 170]},
)

fig = go.Figure(data=[scatter_dKP, ideal], layout=layout)

# Fixed square canvas, white template, no grid lines.
fig.update_layout(
    autosize=False,
    font_size=20,
    width=600,
    height=600,
    template="simple_white",
    xaxis={"showgrid": False},
    yaxis={"showgrid": False},
)

fig.show()

MAE = 9.81547262754275
MAPE = 4.112220246455755
RMSE = 24.328435097010104
Rho_sp = 0.7472471559544952


## Errors on dKP masked

In [18]:
# Regression metrics of the masked predicted dKP against the reference dKP.
dKP_reference = df_pred_matten["dKP_true"]
dKP_predicted_masked = df_pred_matten["dKP_matten_masked"]

mae = mean_absolute_error(dKP_reference, dKP_predicted_masked)
mape = mean_absolute_percentage_error(dKP_reference, dKP_predicted_masked)
rmse = root_mean_squared_error(dKP_reference, dKP_predicted_masked)
spearmanrho = spearmanr(dKP_reference, dKP_predicted_masked)
print(f"MAE = {mae}")
print(f"MAPE = {mape}")
print(f"RMSE = {rmse}")
print(f"Rho_sp = {spearmanrho.statistic}")

# Parity plot: one marker per hold-out entry, labelled with its index value.
scatter_dKP = go.Scatter(
    x=dKP_reference,
    y=dKP_predicted_masked,
    mode="markers",
    name="",
    showlegend=False,
    text=df_pred_matten.index.tolist(),
)

# Dotted y = x line marking a perfect prediction.
diagonal = [-1, 200]
ideal = go.Scatter(
    x=diagonal,
    y=diagonal,
    mode="lines",
    line={"color": "gray", "dash": "dot"},
    showlegend=False,
)

# Both axes share the same range so that the diagonal is at 45 degrees.
layout = go.Layout(
    xaxis={"title": "<i>d</i><sub>KP</sub> (pm/V)", "range": [-1, 170]},
    yaxis={"title": "<i>d&#770;</i><sub>KP</sub> (pm/V)", "range": [-1, 170]},
)

fig = go.Figure(data=[scatter_dKP, ideal], layout=layout)

# Fixed square canvas, white template, no grid lines.
fig.update_layout(
    autosize=False,
    font_size=20,
    width=600,
    height=600,
    template="simple_white",
    xaxis={"showgrid": False},
    yaxis={"showgrid": False},
)

fig.show()

MAE = 10.837173443329645
MAPE = 3.484139614437126
RMSE = 26.033160278579714
Rho_sp = 0.6248695098871155


# Errors on dRMS

In [19]:
# Regression metrics of the predicted dRMS against the reference dRMS.
dRMS_reference = df_pred_matten["dRMS_true"]
dRMS_predicted = df_pred_matten["dRMS_matten"]

mae = mean_absolute_error(dRMS_reference, dRMS_predicted)
mape = mean_absolute_percentage_error(dRMS_reference, dRMS_predicted)
rmse = root_mean_squared_error(dRMS_reference, dRMS_predicted)
spearmanrho = spearmanr(dRMS_reference, dRMS_predicted)
print(f"MAE = {mae}")
print(f"MAPE = {mape}")
print(f"RMSE = {rmse}")
print(f"Rho_sp = {spearmanrho.statistic}")

# Parity plot: one marker per hold-out entry, labelled with its index value.
scatter_dRMS = go.Scatter(
    x=dRMS_reference,
    y=dRMS_predicted,
    mode="markers",
    name="",
    showlegend=False,
    text=df_pred_matten.index.tolist(),
)

# Dotted y = x line marking a perfect prediction.
diagonal = [-1, 200]
ideal = go.Scatter(
    x=diagonal,
    y=diagonal,
    mode="lines",
    line={"color": "gray", "dash": "dot"},
    showlegend=False,
)

# Both axes share the same range so that the diagonal is at 45 degrees.
layout = go.Layout(
    xaxis={"title": "<i>d</i><sub>RMS</sub> (pm/V)", "range": [-1, 100]},
    yaxis={"title": "<i>d&#770;</i><sub>RMS</sub> (pm/V)", "range": [-1, 100]},
)

fig = go.Figure(data=[scatter_dRMS, ideal], layout=layout)

# Fixed square canvas, white template, no grid lines.
fig.update_layout(
    autosize=False,
    font_size=20,
    width=600,
    height=600,
    template="simple_white",
    xaxis={"showgrid": False},
    yaxis={"showgrid": False},
)

fig.show()

MAE = 5.14624876759226
MAPE = 4.177181910650985
RMSE = 12.586826286565612
Rho_sp = 0.7449400470407526
