# WQMG Model Evaluation
This notebook evaluates models using metrics and KDE plots.

In [1]:
# Import libraries for metrics and visualization
import numpy as np
import pandas as pd
from sklearn.metrics import r2_score, mean_absolute_error, mean_squared_error
import matplotlib.pyplot as plt
import seaborn as sns

## Plot Configuration
Set Matplotlib parameters.

In [2]:
# Configure Matplotlib parameters for consistent plot styling
plt.rcParams.update({
    'font.size': 16,              # Set default font size
    'axes.labelsize': 16,         # Set axis label font size
    'xtick.labelsize': 16,        # Set x-tick label font size
    'ytick.labelsize': 16,        # Set y-tick label font size
    'axes.linewidth': 1.2,        # Set axis line width
    'xtick.major.width': 1.2,     # Set x-tick line width
    'ytick.major.width': 1.2,     # Set y-tick line width
    'savefig.format': 'png',      # Save figures in PNG format
    'savefig.bbox': 'tight',      # Use tight layout for saved figures
    'font.family': 'Times New Roman',  # Set font family
    'text.usetex': False          # Disable LaTeX rendering for text
})

## Load Model Predictions
Load prediction data for different models.

In [3]:
# Load LSTM predictions
lstm = pd.read_csv("../models/lstm/lstm_predictions.csv")
lstm.head()

Unnamed: 0,DATE,EV,WQEV,WQMGEV
0,8/23/2016,6.562248,7.432892,6.745247
1,9/12/2016,5.121693,6.663954,5.676876
2,12/1/2016,1.333822,5.461903,5.029943
3,12/21/2016,1.366752,1.364518,1.162341
4,12/31/2016,1.140757,1.666799,1.498404


In [4]:
# Load BO-LSTM predictions
bolstm = pd.read_csv("../models/bolstm/bolstm_predictions.csv")
bolstm.head()

Unnamed: 0,DATE,EV,WQEV,WQMGEV
0,8/23/2016,6.562248,7.258184,6.664568
1,9/12/2016,5.121693,5.655594,4.298963
2,12/1/2016,1.333822,4.071841,1.929978
3,12/21/2016,1.366752,0.763052,1.170905
4,12/31/2016,1.140757,1.328744,1.263597


In [5]:
# Load GRU predictions
gru = pd.read_csv("../models/gru/gru_predictions.csv")
gru.head()

Unnamed: 0,DATE,EV,WQEV,WQMGEV
0,8/23/2016,6.562248,7.48439,6.846831
1,9/12/2016,5.121693,6.510322,5.786075
2,12/1/2016,1.333822,5.297688,4.812537
3,12/21/2016,1.366752,1.428586,0.991425
4,12/31/2016,1.140757,1.455441,1.176014


In [6]:
# Load BO-GRU predictions
bogru = pd.read_csv("../models/bogru/bogru_predictions.csv")
bogru.head()

Unnamed: 0,DATE,EV,WQEV,WQMGEV
0,8/23/2016,6.562248,7.268462,6.853868
1,9/12/2016,5.121693,6.33585,5.538174
2,12/1/2016,1.333822,5.433517,4.224401
3,12/21/2016,1.366752,1.245891,1.249798
4,12/31/2016,1.140757,0.875852,1.777381


## Metric Calculation Functions
Define functions for evaluation metrics.

In [7]:
# Function to calculate mean bias error
def mean_bias_error(y_obs, y_pred):
    return np.mean(y_pred - y_obs)

In [8]:
# Function to compute all metrics
def metrics(data, y):
    obs = data["EV"]
    pred = data[y]
    # Metrics
    mae = mean_absolute_error(obs, pred)
    rmse = np.sqrt(mean_squared_error(obs, pred))
    mbe = mean_bias_error(obs, pred)
    r2 = r2_score(obs, pred)
    
    return r2, mae, rmse, mbe

## Compute Model Metrics
Calculate metrics for each model.

In [9]:
# Compute metrics for each model using WQMGEV
lstm_metrics = metrics(lstm, "WQMGEV")
bolstm_metrics = metrics(bolstm, "WQMGEV")
gru_metrics = metrics(gru, "WQMGEV")
bogru_metrics = metrics(bogru, "WQMGEV")

## Prepare Data for Plotting
Extract predictions for plotting.

In [10]:
# Extract observed EV
y = lstm['EV']
# Extract predictions
x1 = lstm['WQMGEV']
x2 = bolstm['WQMGEV']
x3 = gru['WQMGEV']
x4 = bogru['WQMGEV']

## Model List
Define list of models with metrics and labels.

In [11]:
# List of models, metrics, and labels
models = [
    (x1, lstm_metrics, 'LSTM'),
    (x2, bolstm_metrics, 'BO-LSTM'),
    (x3, gru_metrics, 'GRU'),
    (x4, bogru_metrics, 'BO-GRU')
]

## Generate Evaluation Plots
Create KDE and scatter plots for each model.

In [None]:
# Generate plots for each model
for model, m, label in models:
    fig, ax = plt.subplots(figsize=(7, 6), dpi=600)

    # Plot KDE
    sns.kdeplot(x=model, y=y, levels=10, cmap='CMRmap_r', alpha=0.6, fill=True, ax=ax)
    # Scatter plot
    ax.scatter(y, model, color='deepskyblue', s=60, alpha=0.8, edgecolor='black')
    # 1:1 line
    ax.plot([0, 11], [0, 11], color='red', linestyle='--', linewidth=3)
    ax.scatter([0], [0], label=label, c='white', s=0)

    # Labels
    ax.set_xlabel(r"EV$_{obs}$ (mm/day)", fontsize=18, labelpad=15)
    ax.set_ylabel(r"EV$_{pred}$ (mm/day)", fontsize=18, labelpad=15)

    # Legend
    ax.legend(loc='upper left', fontsize=18, frameon=False, handletextpad=0.0, handlelength=0)

    # Annotate metrics
    ax.text(6, 0.5,
            f"R² = {m[0]:.2f}\nMAE = {m[1]:.2f} (mm/day)\nRMSE = {m[2]:.2f} (mm/day)\nMBE = {m[3]:.2f} (mm/day)",
            fontsize=16, color='black')

    # Set axes limits and ticks
    ax.set_ylim(-0.00, 11)
    ax.set_yticks(np.arange(-0.00, 11, 1))
    ax.set_xlim(-0.00, 11)
    ax.set_xticks(np.arange(-0.00, 11, 1))

    ax.tick_params(axis='both', which='major', labelsize=18)

    # Save plot
    plt.tight_layout()
    plt.savefig(f"./../plots/wqmg-{label.lower()}.png", dpi=600, bbox_inches='tight')

    plt.close()

# plt.show()