# Step 3b -- Evaluate and plot the trained model

In [None]:
import matplotlib.pyplot as plt
import seaborn as sns
import pandas as pd
import numpy as np
import joblib

import config

In [None]:
# Load the test-set scores written by the training step; the first CSV column
# becomes the index.
scores_csv = config.OUT_DIR / "test_scores.csv"
df_scores = pd.read_csv(scores_csv, index_col=0)
df_scores

In [None]:
# Load per-timestamp test predictions; the first column is parsed as dates and
# used as the index (later cells slice it by time).
pred_csv = config.OUT_DIR / "test_predictions.csv"
df_test_pred = pd.read_csv(
    pred_csv,
    index_col=0,
    parse_dates=True,
)
df_test_pred

In [None]:
# Scatter plot: predicted Cn2 against true Cn2 for every test sample
fig, ax = plt.subplots()
sns.scatterplot(data=df_test_pred, x="Cn2_true", y="Cn2_pred", ax=ax)

In [None]:
## QQ plot: sorted predicted values (quantiles) against sorted true values
# Drop rows where either value is missing: np.sort moves NaN to the end, which
# would corrupt the endpoints used for the reference line below.
df_qq = df_test_pred[["Cn2_true", "Cn2_pred"]].dropna()
fig, ax = plt.subplots()
x = np.sort(df_qq["Cn2_true"].to_numpy())
y = np.sort(df_qq["Cn2_pred"].to_numpy())
ax.scatter(x, y, s=2)
# Reference is the identity line y = x: points on it mean the predicted
# distribution matches the true one. A line through the sorted extremes would
# absorb any scale/offset mismatch and hide it.
lims = [min(x[0], y[0]), max(x[-1], y[-1])]
ax.plot(lims, lims, color="black", linestyle="--")
ax.set_xlabel("Cn2_true quantiles")
ax.set_ylabel("Cn2_pred quantiles")

In [None]:
## Randomly selected weeks as example
n = 20  # number of example weeks to plot
n_cols = 4
seed = 0  # fixed seed so re-running the cell shows the same weeks
dt = pd.Timedelta("7D")

# Label-based range slicing with .loc needs a sorted index when the start/end
# labels are not exact index entries.
df_weeks = df_test_pred.sort_index()
# floor (not round) so every candidate start is midnight of a day that actually
# contains data; round() would push afternoon timestamps to the next day.
days = df_weeks.index.floor("D").unique()
rng = np.random.default_rng(seed)
# choice(..., replace=False) raises if asked for more days than exist.
n = min(n, len(days))
days = np.sort(rng.choice(days, n, replace=False))

# Ceil division so no selected week is dropped when n % n_cols != 0.
n_rows = max(1, int(np.ceil(n / n_cols)))
# squeeze=False keeps axarr 2-D even for a 1x1 grid, so .flat always works.
fig, axarr = plt.subplots(
    n_rows, n_cols, figsize=(5 * n_cols, 3 * n_rows), sharey="all", squeeze=False
)

for ax, day_start in zip(axarr.flat, days):
    start = pd.Timestamp(day_start)
    df_i = df_weeks.loc[start:start + dt]
    ax.plot(df_i.index, df_i["Cn2_true"], label="True")
    ax.plot(df_i.index, df_i["Cn2_pred"], label="Pred")
    ax.set_title(start.strftime("%Y-%m-%d"))
    ax.tick_params(axis="x", labelrotation=30)

# Hide grid cells left over when n does not fill the last row.
for ax in axarr.flat[n:]:
    ax.set_visible(False)

# The lines were labelled; one legend is enough since all panels share styling.
axarr[0, 0].legend()
fig.tight_layout()


In [None]:
# Load the per-feature SHAP importance table; feature names become the index.
fi_csv = config.OUT_DIR / "shap_fi.csv"
df_fi = pd.read_csv(fi_csv, index_col=0)
df_fi

In [None]:
# Horizontal bar chart of the per-feature importances (one bar per row of df_fi)
fig, ax = plt.subplots(figsize=(3, 8))
df_fi.plot(kind="barh", ax=ax)

In [None]:
# Map each aggregate group name to the individual features it combines.
# NOTE(review): summing importances within a group assumes the values in
# shap_fi.csv are additive per-feature scores (e.g. mean |SHAP|) -- confirm.
era5_grouped = {
    "rad": [
        "msdrswrf", "msdwlwrf", "msdwswrf", "msnlwrf", "msnswrf"
    ],
    "hr": ["sin_hr", "cos_hr"],
    "day": ["sin_day", "cos_day"],
    "month": ["sin_month", "cos_month"],
    "cloud": ["lcc", "tcc"],
    "X": ["sin_X10", "sin_X100", "cos_X10", "cos_X100"],
    "dT": ["dT0", "dT1"],
}
# Every feature already covered by a group, built once as a set so the
# membership test below is O(1) instead of re-concatenating all groups per feature.
grouped_features = {feat for members in era5_grouped.values() for feat in members}
# Features not in any group are carried over individually.
others = [c for c in df_fi.index if c not in grouped_features]

# Column-wise sum of each group's member rows (df_fi.loc raises KeyError if a
# listed feature is missing from df_fi's index).
fi_grouped = {k: df_fi.loc[v].sum() for k, v in era5_grouped.items()}

# A dict of Series builds groups as columns; transpose to one row per group,
# append the ungrouped features, then rank by importance.
df_fi_grouped = pd.DataFrame(fi_grouped).T
df_fi_grouped = pd.concat([df_fi_grouped, df_fi.loc[others]])
df_fi_grouped = df_fi_grouped.sort_values(by="fi_shap", ascending=False)
df_fi_grouped

In [None]:
# Horizontal bar chart of the grouped importances (groups plus ungrouped features)
fig, ax = plt.subplots(figsize=(3, 8))
df_fi_grouped.plot(kind="barh", ax=ax)