In [14]:
# File: DungeonShap_MultiAgent.py
import numpy as np
import shap
import xgboost as xgb
import matplotlib.pyplot as plt
import onnxruntime as ort
import json

In [25]:
# === CONFIG ===
BIN_PATHS = [
    "AgentObservations_Agent0.bin",
    "AgentObservations_Agent1.bin",
    "AgentObservations_Agent2.bin"
]
ACTION_BIN_PATHS = [
    "AgentActions_Agent0.bin",
    "AgentActions_Agent1.bin",
    "AgentActions_Agent2.bin"
]
SHAP_PATH = 'Dungeon_shap_agent'
ONNX_PATH = "Dungeon_with_logits.onnx"
OBS_SIZE = 410
NUM_ACTIONS = 7
RAY_FEATURES = 8
RAYS_PER_SENSOR = 17
NUM_SENSORS = 3
RAY_TOTAL = RAYS_PER_SENSOR * NUM_SENSORS  # 51 rays

In [26]:
# === TAG DEFINITIONS ===
TagLabels = ["Wall", "Agent", "Dragon", "Key", "Lock", "Portal"]
SemanticGroups = {f"{tag}": [] for tag in TagLabels}
SemanticGroups["Misses"] = []
SemanticGroups["Proximity"] = []
SemanticGroups["HasKey"] = [408]
SemanticGroups["DragonDead"] = [409]

for ray_i in range(RAY_TOTAL):
    base = ray_i * RAY_FEATURES
    for tag_i, tag in enumerate(TagLabels):
        SemanticGroups[f"{tag}"].append(base + tag_i)
    SemanticGroups["Misses"].append(base + 6)
    SemanticGroups["Proximity"].append(base + 7)

group_names = list(SemanticGroups.keys())

In [27]:
# === LOAD MODEL ===
session = ort.InferenceSession(ONNX_PATH)
input_names = [i.name for i in session.get_inputs()]
output_names = [o.name for o in session.get_outputs()]

dummy_input = {
    "obs_0": np.zeros((1, 408), dtype=np.float32),
    "obs_1": np.zeros((1, 2), dtype=np.float32),
    "action_masks": np.ones((1, NUM_ACTIONS), dtype=np.float32)
}
outputs = session.run(None, dummy_input)
logits_name = [name for name, out in zip(output_names, outputs) if out.shape == (1, NUM_ACTIONS)][0]
print("✅ Using logits output:", logits_name)

✅ Using logits output: /_discrete_distribution/Softmax_output_0


In [28]:
# === PREDICT FUNCTION ===
def predict_fn(X):
    return session.run([logits_name], {
        "obs_0": X[:, :408],
        "obs_1": X[:, 408:],
        "action_masks": np.ones((X.shape[0], NUM_ACTIONS), dtype=np.float32)
    })[0]

In [30]:
#load all action data and check lengths
action_data = [np.fromfile(path, dtype=np.float32) for path in ACTION_BIN_PATHS]
for i, data in enumerate(action_data):
    print(f"Action data {i} length: {len(data)}")

#load all observation data and check lengths
obs_data = [np.fromfile(path, dtype=np.float32) for path in BIN_PATHS]
for i, data in enumerate(obs_data):
    print(f"Observation data {i} length: {len(data)/410}")

Action data 0 length: 9656
Action data 1 length: 7247
Action data 2 length: 7131
Observation data 0 length: 9656.0
Observation data 1 length: 7247.0
Observation data 2 length: 7131.0


In [31]:

semantic_shap_per_obs = [[],[],[]]

for i, path in enumerate(BIN_PATHS):
    data = np.fromfile(path, dtype=np.float32)
    remainder = data.size % OBS_SIZE
    if remainder != 0:
        print(f"Trimming {remainder} extra floats from {path}")
        data = data[:-remainder]
    X = data.reshape(-1, OBS_SIZE)
    X = X[np.isfinite(X).all(axis=1)]

    print(f"Loaded {X.shape[0]} valid observations from {path}")

    actions_taken = np.fromfile(ACTION_BIN_PATHS[i], dtype=np.int32)
    if actions_taken.size != X.shape[0]:
        print(f'num actions_taken ({actions_taken.size}) does not match num observations ({X.shape[0]})')
        break
    else:
        print(f"Loaded {actions_taken.shape[0]} actions_taken from {ACTION_BIN_PATHS[i]}")


    Y = predict_fn(X)
    Y = Y[np.isfinite(Y).all(axis=1)]
    X = X[:len(Y)]  # Match length

    for action_index in range(NUM_ACTIONS):
        indices = np.where(actions_taken == action_index)[0]
        if len(indices) == 0:
            continue

        print(f"🎯 Training surrogate for Action {action_index}")
        model = xgb.XGBRegressor(n_estimators=200, max_depth=15)
        model.fit(X, Y[:, action_index])

        explainer = shap.Explainer(model)
        shap_values = explainer(X[indices])

        for obs_i, shap_val in zip(indices, shap_values.values):
            sem_scores = {}
            for group in group_names:
                indices_in_group = SemanticGroups[group]
                sem_scores[group] = float(np.abs(shap_val[indices_in_group]).mean())

            semantic_shap_per_obs[i].append({
                "observation_index": int(obs_i),
                "action": int(action_index),
                "semantic_shap": sem_scores
            })

Loaded 9656 valid observations from AgentObservations_Agent0.bin
Loaded 9656 actions_taken from AgentActions_Agent0.bin
🎯 Training surrogate for Action 0
🎯 Training surrogate for Action 1
🎯 Training surrogate for Action 2
🎯 Training surrogate for Action 3
🎯 Training surrogate for Action 4
🎯 Training surrogate for Action 5
🎯 Training surrogate for Action 6
Loaded 7247 valid observations from AgentObservations_Agent1.bin
Loaded 7247 actions_taken from AgentActions_Agent1.bin
🎯 Training surrogate for Action 0
🎯 Training surrogate for Action 1
🎯 Training surrogate for Action 2
🎯 Training surrogate for Action 3
🎯 Training surrogate for Action 4
🎯 Training surrogate for Action 5
🎯 Training surrogate for Action 6
Loaded 7131 valid observations from AgentObservations_Agent2.bin
Loaded 7131 actions_taken from AgentActions_Agent2.bin
🎯 Training surrogate for Action 0
🎯 Training surrogate for Action 1
🎯 Training surrogate for Action 2
🎯 Training surrogate for Action 3
🎯 Training surrogate for Act

In [32]:
# === SAVE JSON ===
for x in range(0, 3):
    # === SAVE JSON ===
    with open(f"{SHAP_PATH}{x}.json", "w") as f:
        json.dump(semantic_shap_per_obs[x], f, indent=2)

    print(f"✅ Saved semantic SHAP values to {SHAP_PATH}")

✅ Saved semantic SHAP values to Dungeon_shap_agent
✅ Saved semantic SHAP values to Dungeon_shap_agent
✅ Saved semantic SHAP values to Dungeon_shap_agent
