# **SB3 model explainability with SHAP values**

First, we load the required libraries and specify the model and dataset paths.

In [None]:
import shap
import torch

import numpy as np

from pandas import read_csv

from stable_baselines3 import SAC

# Input file locations, relative to the working directory.
MODEL_PATH = './model.zip'
MONITOR_PATH = './monitor_norm.csv'
PROGRESS_PATH = './progress.csv'

# Prefer the GPU when CUDA is usable; otherwise fall back to the CPU.
if torch.cuda.is_available():
    device = torch.device('cuda')
else:
    device = torch.device('cpu')

The SB3 agent is loaded. Each observation consists of 82 floating-point values, while each action is an array of 6 values in `[-1,1]`.

In [None]:
# Restore the trained SAC agent from the saved archive.
agent = SAC.load(MODEL_PATH)

# Report both spaces with a single print call, one per line.
print(
    f'Observation space: {agent.observation_space}',
    f'Action space: {agent.action_space}',
    sep='\n',
)

We load the dataset and preview it.

In [None]:
# Read the monitor CSV into a DataFrame.
monitor = read_csv(MONITOR_PATH)
# Last expression of the cell: shows the first rows as a quick preview.
monitor.head()

Columns/variables that the agent did not perceive during training are removed. Rows containing `NaN` values are also dropped.

In [None]:
# Logged columns that were not part of what the agent perceived in training.
to_remove = [
    'timestep',
    'reward',
    'reward_energy_term',
    'reward_comfort_term',
    'time (hours)',
    'absolute_energy_penalty',
    'absolute_comfort_penalty',
    'terminated',
    'truncated',
    'total_temperature_violation',
    'flow_livroom',
    'total_power_demand',
    'flow_kitchen',
    'flow_bed1',
    'flow_bed2',
    'flow_bed3',
    'water_temperature',
]

# Drop those columns first, then discard every row that still holds a NaN.
data = monitor.drop(columns=to_remove)
data = data.dropna()

data.head()

Once the dataset and the model are both ready, we test them by doing some sample predictions.

In [None]:
def model_predict(data, model):
    """Run ``model`` on a batch of tabular observations without gradients.

    Args:
        data: Table whose ``.values`` attribute yields a numeric array
            (e.g. a pandas DataFrame), one row per observation.
        model: Callable torch module that maps a float32 tensor to a tensor
            (e.g. ``agent.policy``).

    Returns:
        The model output, moved to the CPU.
    """
    # Feed the input on the device that actually holds the model's weights so
    # the call cannot fail with a device mismatch. The notebook-level
    # ``device`` is only a fallback for callables that expose no parameters.
    try:
        target_device = next(model.parameters()).device
    except (AttributeError, StopIteration):
        target_device = device

    data_tensor = torch.tensor(data.values, dtype=torch.float32).to(target_device)
    with torch.no_grad():
        actions = model(data_tensor)
    return actions.cpu()

# Smoke test: run the policy on the first five rows of the dataset.
model_predict(data.iloc[:5], agent.policy)

Now we will calculate the corresponding SHAP values. 

The `SACPolicyWrapper` class wraps the model so that `predict` returns only the action array instead of a tuple.

We use a `KernelExplainer`, which is quite robust and model-agnostic.



In [None]:
# Number of rows drawn from `data` for the SHAP analysis (used by the `data.sample` call below).
SHAP_SAMPLES = 500

# def to_hash(action):
#     hashes = []
#     for a in action:
#         hashes.append(hash(tuple(a)))
#     return np.array(hashes)

class SACPolicyWrapper:
    """Adapter exposing only the action part of a model's ``predict`` result."""

    def __init__(self, model):
        # Keep a reference to the wrapped model; nothing is copied.
        self.model = model

    def predict(self, obs):
        """Return the deterministic action(s) the wrapped model gives for ``obs``.

        ``self.model.predict`` yields a pair; only its first element (the
        action) is returned and the second one is discarded.
        """
        prediction = self.model.predict(obs, deterministic=True)
        actions, _unused = prediction
        # Alternative scalar reductions of the action, kept for reference:
        #   np.mean(action, axis=1)          (mean)
        #   np.linalg.norm(action, axis=1)   (L2 norm)
        #   to_hash(action)
        return actions
    

# Single wrapper instance shared by every call of the function below.
policy_wrapper = SACPolicyWrapper(agent)


def sac_policy_predict(obs):
    """Plain-function entry point forwarding ``obs`` to ``policy_wrapper.predict``."""
    actions = policy_wrapper.predict(obs)
    return actions


# Draw the rows to explain. The sample size is capped at the number of
# available rows because sampling without replacement raises a ValueError
# when more rows are requested than exist.
X = data.sample(n=min(SHAP_SAMPLES, len(data)), replace=False)

# X is used both as the explainer's background dataset and as the set of
# samples being explained.
explainer = shap.KernelExplainer(sac_policy_predict, X)

shap_values = explainer.shap_values(X)

Once the explainer is fitted and the SHAP values are computed, we create an `Explanation` object and plot the results.

In [None]:
# explanation = shap.Explanation(
#     values=shap_values[:,:,1],
#     base_values=explainer.expected_value[1],
#     data=X,
#     feature_names=features
# )

# shap.plots.beeswarm(explanation.abs, max_display=len(features), color='shap_red')
# shap.plots.beeswarm(explanation.abs, color='shap_red')

# shap.plots.beeswarm(explanation, max_display=len(features))
# shap.plots.beeswarm(explanation)


In [None]:
features = X.columns.tolist()

# Depending on the SHAP release, a multi-output model yields either a single
# array shaped (samples, features, outputs) or a list with one
# (samples, features) array per output. Normalise both to the stacked form so
# that axis 2 is always the output axis.
if isinstance(shap_values, list):
    shap_array = np.stack(shap_values, axis=-1)
else:
    shap_array = np.asarray(shap_values)

# Average the attributions over the outputs so each feature gets one value per
# sample; a 2-D result (single output) has nothing to average and is used as is.
mean_shap_values = shap_array.mean(axis=2) if shap_array.ndim == 3 else shap_array

explanation = shap.Explanation(
    values=mean_shap_values,
    # Mean of the per-output expected values, matching the averaged attributions.
    base_values=np.mean(explainer.expected_value),
    data=X,
    feature_names=features,
)

shap.plots.beeswarm(explanation, max_display=len(features))
