# Model Explainability

## Introduction

This notebook loads evaluation results from the `/artifacts` directory and generates SHAP explainability plots for the model developed in  
J. Albert-Smet, Z. Frias, L. Mendo, S. Melones, and E. Yraola, *“Characterizing 5G User Throughput via Uncertainty Modeling and Crowdsourced Measurements,”* arXiv preprint arXiv:2510.09239, Oct. 2025.

The goal is to analyze feature importance and model behavior using SHAP values, providing a transparent interpretation of how different network parameters contribute to 5G user throughput predictions.

This notebook is organized as follows:

**1. Prepare Data**  
   Load configuration files, SHAP values, and feature definitions required for analysis.

**2. Visualize Plots**  
   **2.1. Normalized Mean Absolute SHAP Values** – Summarizes standard SHAP beeswarm plots by computing the mean of absolute SHAP values. This provides a high-level comparison of feature influence from 4G to 5G SA.  
   **2.2. Feature Value vs. SHAP Values** – Compares SHAP values with the corresponding feature values, showing both the magnitude and direction of influence. This is illustrated for *Frequency Band*, *RSRP*, and *DL TTFB*, the three most influential features in the 5G SA scenario. The spread of SHAP values also reflects NGBoost’s ability to produce more stable central estimates.  
   **2.3. Feature Importance Ratios** – Quantifies the evolution of feature relevance from 4G to 5G SA, complementing the qualitative observations from section 2.1.


## 1. Prepare Data

### Imports

In [None]:
# Third-party imports (joblib is a third-party package, not part of the standard library)
import joblib
import matplotlib.patches as mpatches
import matplotlib.pyplot as plt
import numpy as np
import pandas as pd
from matplotlib import cm

# Local imports
from char5g.utils import (
    get_latest_experiment_dir,
    get_project_root,
    load_config,
    categorize_features,
)

### Define Experiment Directory

In [None]:
# Option 1: Get results from latest experiment
project_root = get_project_root()
exp_dir = get_latest_experiment_dir(project_root / "artifacts")
config = load_config(exp_dir / "config.yaml")

# Option 2: Provide experiment directory and config file
# exp_dir = ".../..."
# config = load_config(".../...")

### Extract and Categorize Features

In [None]:
features = config["data"]["features"] + config["data"]["temporal_encodings"]
features_by_category = categorize_features(features)

deployment = features_by_category["deployment"]
radio = features_by_category["radio"]
e2e = features_by_category["e2e"]
context = features_by_category["context"]

features_order = radio + e2e + context + deployment

### Load SHAP values and Modeling Data

In [None]:
shap_values_test = joblib.load(exp_dir / "eval" / "test_shap_values.pkl")
Xy_train_val_test = joblib.load(exp_dir / "data_splits" / "Xy_train_val_test.pkl")
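
The cells that follow rely on a particular layout for these two artifacts: `shap_values_test[tech][model]` exposes a `.values` array with one column per entry in `features`, and `Xy_train_val_test[tech]["X_test"]` holds the matching test rows. A quick consistency check along these lines can catch a mismatched experiment directory early. This is only a sketch: `check_shap_layout` is a hypothetical helper that is not part of `char5g`, and the toy objects below merely stand in for the real pickles.

```python
from types import SimpleNamespace

import numpy as np


def check_shap_layout(shap_values, splits, features):
    """Return a list of layout problems; an empty list means shapes are consistent."""
    problems = []
    for tech, models in shap_values.items():
        n_rows = len(splits[tech]["X_test"])
        expected = (n_rows, len(features))
        for model_name, explanation in models.items():
            shape = np.shape(explanation.values)
            if shape != expected:
                problems.append(f"{tech}/{model_name}: got {shape}, expected {expected}")
    return problems


# Toy stand-ins (hypothetical data): the second model is deliberately one column short.
toy_features = ["RSRP", "RSRQ", "DL TTFB"]
toy_shap = {
    "5G_SA": {
        "XGBoost": SimpleNamespace(values=np.zeros((4, 3))),
        "NGBoost_mean": SimpleNamespace(values=np.zeros((4, 2))),
    }
}
toy_splits = {"5G_SA": {"X_test": np.zeros((4, 3))}}

toy_problems = check_shap_layout(toy_shap, toy_splits, toy_features)
print(toy_problems)
```

On the real artifacts the equivalent call would be `check_shap_layout(shap_values_test, Xy_train_val_test, features)`, assuming the pickles follow the layout described above.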

### Create DataFrame: Mean Absolute SHAP Values per Model per Technology

In [None]:
# Initialize a DataFrame to store mean absolute SHAP values
columns = [
    f"{tech}_{model}"
    for tech in shap_values_test.keys()
    for model in shap_values_test[tech].keys()
]

shap_avg_df = pd.DataFrame(index=features, columns=columns, dtype=float)

In [None]:
# Compute mean absolute SHAP values for each (technology, model) pair
for tech, models in shap_values_test.items():
    for model_name, shap_vals in models.items():
        col_name = f"{tech}_{model_name}"
        shap_avg_df[col_name] = np.abs(shap_vals.values).mean(axis=0)

In [None]:
# Compute the normalized values to enable comparison between the models
df_norm = shap_avg_df.div(shap_avg_df.sum(axis=0), axis=1)
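
To make the two steps above concrete, the sketch below applies the same mean-absolute and normalization operations to a small made-up SHAP matrix. The numbers, `toy_features`, and the `toy_model` column are purely illustrative and do not come from the experiment.

```python
import numpy as np
import pandas as pd

# Toy SHAP matrix: 4 samples x 3 features (made-up numbers).
toy_shap = np.array([
    [0.2, -0.1, 0.4],
    [-0.6, 0.1, 0.0],
    [0.4, -0.3, -0.2],
    [-0.4, 0.3, 0.2],
])
toy_features = ["feat_a", "feat_b", "feat_c"]

# Step 1: mean of absolute SHAP values per feature.
# Taking |.| first prevents positive and negative contributions from cancelling.
mean_abs = pd.DataFrame(
    {"toy_model": np.abs(toy_shap).mean(axis=0)},
    index=toy_features,
)

# Step 2: divide each column by its sum so every model's importances add up to 1.
toy_norm = mean_abs.div(mean_abs.sum(axis=0), axis=1)
print(toy_norm)
```

After normalization each column sums to 1, so the stacked bars of different models can be compared on the same scale even when their raw SHAP magnitudes differ.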

## 2. Visualize Plots

### Mean Absolute SHAP Values per Model per Technology

In [None]:
# Assign colors by feature category for consistent plotting
category_colors = {
    "deployment": cm.Greys,
    "radio": cm.Blues,
    "e2e": cm.Reds,
    "context": cm.Greens,
}

# Build the color map
color_map = {}
for category, cmap in category_colors.items():
    feats = features_by_category[category]
    color_map.update({
        f: c for f, c in zip(feats, cmap(np.linspace(0.4, 0.7, len(feats))))
    })

In [None]:
fig, ax = plt.subplots(figsize=(8, 4.5), dpi=150)

bottom_vals = np.zeros(len(df_norm.columns))
for feature in df_norm.index:
    ax.bar(df_norm.columns, df_norm.loc[feature],
           bottom=bottom_vals,
           label=feature,
           color=color_map[feature],   # same mapping as before
           edgecolor="black", linewidth=0.5,
           width=0.5)
    bottom_vals += df_norm.loc[feature].values

ax.axvline(x=2.5, color="black", linestyle="-", linewidth=1)  # between 3rd and 4th ticks
ax.axvline(x=5.5, color="black", linestyle="-", linewidth=1)  # between 6th and 7th ticks

# Hatched background encodes the model family behind each bar:
# orange for XGBoost (first bar of each group), blue for NGBoost (next two bars)
hatch_kwargs = dict(facecolor="none", hatch="///", alpha=0.6, linewidth=0.0, zorder=0)
for group_start in (0, 3, 6):
    ax.axvspan(group_start - 0.5, group_start + 0.5, edgecolor="orange", **hatch_kwargs)
    ax.axvspan(group_start + 0.5, group_start + 2.5, edgecolor="blue", **hatch_kwargs)

ax.set_ylabel("Normalized Importance")
ax.grid(axis="y", linestyle="--", alpha=0.5)
ax.set_ylim(0, 1)

tick_labels = (["XGB", r"NGB $\mu$", r"NGB $\sigma$"] * 3)  # repeat 3 times
ax.set_xticks(np.arange(len(df_norm.columns)))
ax.set_xticklabels(tick_labels, rotation=0, fontsize=9)

ax.tick_params(axis="x", which="major", pad=5)  

group_labels = ["4G", "5G_NSA", "5G_SA"]
group_sizes = [3, 3, 3]  # bars per technology group

# Center each technology label under its group of bars
start = 0
for label, size in zip(group_labels, group_sizes):
    center = start + (size-1)/2
    ax.text(center, -0.08, label, ha="center", va="top",
            transform=ax.get_xaxis_transform(),
            fontsize=11)
    start += size

h_xgb = mpatches.Patch(facecolor="none", edgecolor="orange",
                       hatch="///", label="XGBoost", linewidth=0.0)
h_ngb = mpatches.Patch(facecolor="none", edgecolor="blue",
                       hatch="///", label="NGBoost", linewidth=0.0)

feat_legend = ax.legend(title="Feature", bbox_to_anchor=(1, 1), loc="upper left")

model_legend = ax.legend(handles=[h_xgb, h_ngb], title="Model encoding",
                         bbox_to_anchor=(1, 0.2), loc="upper left", frameon=True)
ax.add_artist(feat_legend)

plt.tight_layout()
plt.show()

[Fig. 3. Average feature importance plot for the XGBoost and NGBoost models
for 4G, 5G NSA and 5G SA. (PDF)](Fig3.Feature_importance.pdf)
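
The group-label centers and the vertical separator positions used in the stacked-bar cell follow from simple arithmetic on the group sizes. The sketch below reproduces that calculation on its own, assuming bars sit at integer x positions 0, 1, 2, and so on, as they do when tick positions are set with `np.arange`.

```python
group_sizes = [3, 3, 3]  # three bars (XGB, NGB mu, NGB sigma) per technology

centers = []
separators = []
start = 0
for size in group_sizes:
    # Midpoint of the tick positions start .. start + size - 1
    centers.append(start + (size - 1) / 2)
    start += size
    # Boundary halfway between the last bar of this group and the first of the next
    separators.append(start - 0.5)

separators = separators[:-1]  # no separator is needed after the last group

print(centers)
print(separators)
```

Deriving both lists from `group_sizes` keeps labels and separators aligned if the number of models per technology ever changes.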

### Feature Values vs SHAP Values for 5G SA

In [None]:
tech = "5G_SA"
fontsize = 12

In [None]:
fig, axs = plt.subplots(1, 3, figsize=(8, 3), sharey=True, constrained_layout=True)

X_shap = Xy_train_val_test[tech]["X_test"]
xgb_shap = shap_values_test[tech]["XGBoost"]
ngboost_shap_mean = shap_values_test[tech]["NGBoost_mean"]

for idx, feature_name in enumerate(["Frequency Band", "RSRP", "DL TTFB"]):

    # X_test and the SHAP arrays are both indexed by position in `features`
    feature_idx = features.index(feature_name)

    axs[idx].axhline(0, color="gray", linestyle="--", alpha=0.7, label="Zero SHAP")

    axs[idx].scatter(X_shap[:, feature_idx], xgb_shap.values[:, feature_idx],
                     color="blue", alpha=0.5, label="XGBoost")
    axs[idx].scatter(X_shap[:, feature_idx], ngboost_shap_mean.values[:, feature_idx],
                     color="orange", alpha=0.5, label=r"NGBoost $\mu$")

    # axis labels
    if feature_name == "RSRP":
        axs[idx].set_xlabel(feature_name + " (dBm)", fontsize=fontsize)
    elif feature_name == "Frequency Band":
        axs[idx].set_xlabel(feature_name + " (MHz)", fontsize=fontsize)
    elif feature_name == "DL TTFB":
        axs[idx].set_xlabel(feature_name + " (ms)", fontsize=fontsize)

    axs[idx].set_ylim([-2.5, 1.5])

    if feature_name == "DL TTFB":
        axs[idx].set_xscale("log")
        axs[idx].set_xlim([75, 1.5e3])

    axs[idx].grid(True, linestyle="--", alpha=0.6)

    # tick labels larger
    axs[idx].tick_params(axis="both", labelsize=fontsize)

# y-axis label
axs[0].set_ylabel("SHAP value", fontsize=fontsize)

# legend with larger font
axs[2].legend(fontsize=fontsize - 2)

plt.show()


[Fig. 4. SHAP vs feature value plots of XGBoost and NGBoost for the three
most important inputs in 5G SA networks. (PDF)](Fig4.SHAP_spread.pdf)

### Ratios of Feature and Feature-Group SHAP Values

In [None]:
# Compare total normalized SHAP importance between E2E and Radio features
e2e_vs_radio_ratio = df_norm.loc[e2e].sum(axis=0) / df_norm.loc[radio].sum(axis=0)

print("Ratio of total normalized SHAP importance between E2E and Radio feature groups:")
print(e2e_vs_radio_ratio)
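
Because both group sums in a column are divided by the same column total, this ratio is the same whether it is computed on `df_norm` or on the unnormalized `shap_avg_df`. The sketch below illustrates this on made-up numbers; the feature names, model names, and values are hypothetical.

```python
import numpy as np
import pandas as pd

# Made-up mean |SHAP| values for two hypothetical models.
raw = pd.DataFrame(
    {"model_a": [0.30, 0.10, 0.20, 0.40], "model_b": [1.0, 1.0, 4.0, 2.0]},
    index=["radio_1", "radio_2", "e2e_1", "e2e_2"],
)
norm = raw.div(raw.sum(axis=0), axis=1)  # each column now sums to 1

toy_radio = ["radio_1", "radio_2"]
toy_e2e = ["e2e_1", "e2e_2"]

# Group ratio on raw and on normalized importances
ratio_raw = raw.loc[toy_e2e].sum(axis=0) / raw.loc[toy_radio].sum(axis=0)
ratio_norm = norm.loc[toy_e2e].sum(axis=0) / norm.loc[toy_radio].sum(axis=0)

print(ratio_raw)
print(ratio_norm)
```

A ratio above 1 means the E2E group carries more total importance than the radio group for that model, which makes the ratio a scale-free way to compare models and technologies.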

In [None]:
# Compute relative importance of RSRQ compared to total Radio features
rsrq_ratio = df_norm.loc["RSRQ"] / df_norm.loc[radio].sum(axis=0)

print("RSRQ relative contribution (vs. total radio features):")
print(rsrq_ratio)

In [None]:
# Compute relative importance of RSRP compared to total Radio features
rsrp_ratio = df_norm.loc["RSRP"] / df_norm.loc[radio].sum(axis=0)

print("RSRP relative contribution (vs. total radio features):")
print(rsrp_ratio)