In [2]:
import numpy as np
import shap
import xgboost as xgb
import matplotlib.pyplot as plt
import onnxruntime as ort
import json

  from .autonotebook import tqdm as notebook_tqdm


In [3]:
# === CONFIG ===
# Input recordings (flat binary dumps) and the exported policy network.
BIN_PATH = "AgentObservations.bin"
ACTIONS_PATH = "AgentActions.bin"
ONNX_PATH = "Pyramids2_with_logits.onnx"

# Destination for the per-observation semantic SHAP scores.
SHAP_JSON_PATH = "semantic_shap.json"

# Width of one observation row and number of discrete actions.
OBS_SIZE = 172
NUM_ACTIONS = 5

# === TAG DEFINITIONS ===
# Order matters: the position of a tag is used as its offset inside each ray.
TagLabels = [
    "Block",
    "Wall",
    "Goal",
    "SwitchOff",
    "SwitchOn",
    "Stone",
]

In [4]:
# === LOAD OBSERVATIONS ===
def load_observations(path, obs_size=None):
    """Load a flat float32 dump and return it as finite observation rows.

    Args:
        path: File containing raw float32 values, read with ``np.fromfile``.
        obs_size: Number of floats per observation row. Defaults to the
            module-level ``OBS_SIZE`` when omitted.

    Returns:
        A 2-D float32 array of shape ``(n_rows, obs_size)`` containing only
        the rows in which every value is finite.

    Raises:
        ValueError: If ``obs_size`` is not a positive integer.
    """
    if obs_size is None:
        obs_size = OBS_SIZE
    if obs_size <= 0:
        raise ValueError(f"obs_size must be positive, got {obs_size}")

    data = np.fromfile(path, dtype=np.float32)

    # A partially written last row cannot be reshaped; discard its floats.
    remainder = data.size % obs_size
    if remainder != 0:
        print(f"Trimming {remainder} extra floats")
        data = data[:-remainder]

    reshaped = data.reshape(-1, obs_size)
    mask = np.isfinite(reshaped).all(axis=1)

    # Dropping rows shifts the index of every later row, so say so out loud:
    # anything that is row-aligned with this file (e.g. a separately recorded
    # actions file) will no longer line up.
    dropped = int(mask.size - mask.sum())
    if dropped:
        print(f"Dropped {dropped} rows containing non-finite values")

    return reshaped[mask]

# Read the whole recording once; X is the feature matrix used by every later cell.
X = load_observations(path=BIN_PATH)
print(f"Loaded {X.shape[0]} valid observations")

Loaded 12696 valid observations


In [5]:
# === SEMANTIC GROUPING ===
# Maps a human-readable group name to the observation column indices it owns.
#
# Layout assumed here (it mirrors the obs_0..obs_3 split fed to the ONNX model):
#   NUM_RAY_SENSORS sensors x RAYS_PER_SENSOR rays, each ray being
#   [one flag per tag in TagLabels, miss flag, proximity value],
#   followed by the scalar features (switch state, then 3 velocity components).
# NOTE(review): the per-ray ordering is taken from the group names used below;
# confirm it against the sensor configuration that produced the recording.
NUM_RAY_SENSORS = 3
RAYS_PER_SENSOR = 7
VALUES_PER_RAY = len(TagLabels) + 2          # tag flags + miss + proximity
SENSOR_SIZE = RAYS_PER_SENSOR * VALUES_PER_RAY
VECTOR_OFFSET = NUM_RAY_SENSORS * SENSOR_SIZE  # first column after all rays

SemanticGroups = {f"{tag} Detected": [] for tag in TagLabels}
SemanticGroups["Misses"] = []
SemanticGroups["Proximity"] = []
SemanticGroups["SwitchState"] = [VECTOR_OFFSET]
SemanticGroups["Velocity"] = [VECTOR_OFFSET + 1, VECTOR_OFFSET + 2, VECTOR_OFFSET + 3]

for sensor_i in range(NUM_RAY_SENSORS):
    for ray_i in range(RAYS_PER_SENSOR):
        base = sensor_i * SENSOR_SIZE + ray_i * VALUES_PER_RAY
        for tag_i, tag in enumerate(TagLabels):
            SemanticGroups[f"{tag} Detected"].append(base + tag_i)
        # The two trailing slots of a ray come right after the tag flags.
        SemanticGroups["Misses"].append(base + len(TagLabels))
        SemanticGroups["Proximity"].append(base + len(TagLabels) + 1)

# Every observation column must belong to exactly one group; otherwise the
# layout constants above no longer match OBS_SIZE / TagLabels.
_covered = sorted(i for idxs in SemanticGroups.values() for i in idxs)
if _covered != list(range(OBS_SIZE)):
    raise ValueError(
        f"SemanticGroups covers {len(_covered)} column indices but must "
        f"partition exactly 0..{OBS_SIZE - 1}"
    )

group_names = list(SemanticGroups.keys())

In [6]:
# === LOAD MODEL ===
# Open the exported policy and find which graph output to explain.
session = ort.InferenceSession(ONNX_PATH)
output_names = [o.name for o in session.get_outputs()]

# One all-zero observation, split into the four named inputs the graph takes,
# with every action unmasked. It is only used to discover output shapes.
dummy_input = {
    "obs_0": np.zeros((1, 56), dtype=np.float32),
    "obs_1": np.zeros((1, 56), dtype=np.float32),
    "obs_2": np.zeros((1, 56), dtype=np.float32),
    "obs_3": np.zeros((1, 4), dtype=np.float32),
    "action_masks": np.ones((1, NUM_ACTIONS), dtype=np.float32)
}
outputs = session.run(None, dummy_input)

# Select by shape: the first output with one value per action for a batch of 1.
# NOTE(review): this is a shape match only, so it cannot tell raw logits from
# a softmax/probability output of the same shape - confirm the printed name is
# the tensor you intend to explain.
_candidates = [
    n for n, o in zip(output_names, outputs)
    if getattr(o, "shape", None) == (1, NUM_ACTIONS)
]
if not _candidates:
    raise ValueError(
        f"No model output has shape (1, {NUM_ACTIONS}); available outputs: "
        + ", ".join(
            f"{n}{getattr(o, 'shape', None)}" for n, o in zip(output_names, outputs)
        )
    )
logits_name = _candidates[0]
print("✅ Using logits output:", logits_name)

✅ Using logits output: /_discrete_distribution/Softmax_output_0


In [7]:
# === PREDICTION FUNCTION ===
def predict_fn(X_batch):
    """Run the policy network on a batch of flat observation rows.

    Args:
        X_batch: Array-like of shape ``(n, 172)`` (or a single row of 172
            values). Columns 0-167 are split into three 56-wide inputs and
            columns 168-171 into the 4-wide input, matching the names the
            session is fed below.

    Returns:
        The array produced by the session for the output named
        ``logits_name``, one row per input row.
    """
    # The session was probed with float32 inputs; coerce here so callers that
    # hand over float64 data (e.g. arrays built from Python lists) are fed the
    # same dtype instead of failing inside the runtime.
    X_batch = np.asarray(X_batch, dtype=np.float32)
    if X_batch.ndim == 1:
        # Allow a single observation to be passed without manual reshaping.
        X_batch = X_batch[np.newaxis, :]

    feeds = {
        "obs_0": X_batch[:, 0:56],
        "obs_1": X_batch[:, 56:112],
        "obs_2": X_batch[:, 112:168],
        "obs_3": X_batch[:, 168:172],
        # All actions are left unmasked for every row.
        "action_masks": np.ones((X_batch.shape[0], NUM_ACTIONS), dtype=np.float32)
    }
    return session.run([logits_name], feeds)[0]

# Score every observation, then keep only rows whose predictions are all finite.
Y = predict_fn(X)
valid_rows = np.isfinite(Y).all(axis=1)
X_clean, Y_clean = X[valid_rows], Y[valid_rows]
print(f"✅ Kept {X_clean.shape[0]} clean observations out of {X.shape[0]}")

✅ Kept 12696 clean observations out of 12696


In [None]:
# === LOAD ACTIONS TAKEN ===
# The file is read as int32 values and viewed as rows of two values each; only
# the second value of every row (column index 1) is kept as the action.
actions_taken = np.fromfile(ACTIONS_PATH, dtype=np.int32).reshape(-1, 2)[:, 1]
print(f"Actions taken shape: {actions_taken.shape}")
print(f"Loaded {actions_taken.shape[0]} actions taken")

# Actions are matched to observations by row position, so the counts must agree.
if actions_taken.shape[0] != X_clean.shape[0]:
    raise ValueError("Mismatch between number of actions and observations")

[1 1 1 1 1 1 1 1 1 1 1 1 1 1 1 1 1 1 1 1 1 1 1 1 1 1 1 1 1 1 1 1 1 1 1 1 1
 1 1 1 1 1 1 1 1 1 1 1 1 1 1 1 1 1 1 1 1 1 1 1 1 1 1 1 1 3 3 3 3 3 3 3 3 3
 3 3 3 3 3 3 3 3 3 3 3 3 3 3 3 3 3 3 3 3 3 1 1 1 1 1]
Actions taken shape: (12696,)
Loaded 12696 actions taken


In [24]:
# === SHAP PER ACTION TAKEN ===
# For each action: fit a tree surrogate of the policy's output column for that
# action on all clean observations, then explain only the rows where the agent
# actually took that action. Per-feature SHAP values are folded into one mean
# absolute score per semantic group.
semantic_shap_per_obs = []

for action_index in range(NUM_ACTIONS):
    indices = np.where(actions_taken == action_index)[0]
    if indices.size == 0:
        # Nothing to explain for an action that never occurs in the recording.
        continue

    print(f"🎯 Training surrogate for Action {action_index}")
    model = xgb.XGBRegressor(n_estimators=200, max_depth=15)
    model.fit(X_clean, Y_clean[:, action_index])

    explainer = shap.Explainer(model)
    shap_values = explainer(X_clean[indices])

    for obs_i, shap_val in zip(indices, shap_values.values):
        record = {
            "observation_index": int(obs_i),
            "action": int(action_index),
            "semantic_shap": {
                group: float(np.abs(shap_val[SemanticGroups[group]]).mean())
                for group in group_names
            },
        }
        semantic_shap_per_obs.append(record)

🎯 Training surrogate for Action 0
🎯 Training surrogate for Action 1
🎯 Training surrogate for Action 3
🎯 Training surrogate for Action 4


In [25]:
# === SAVE JSON ===
# Persist the list of per-observation records as pretty-printed JSON.
with open(SHAP_JSON_PATH, mode="w") as out_file:
    json.dump(semantic_shap_per_obs, out_file, indent=2)

print(f"✅ Saved semantic SHAP values to {SHAP_JSON_PATH}")

✅ Saved semantic SHAP values to semantic_shap.json
