In [1]:
import onnx
import torch
import numpy as np
import matplotlib.pyplot as plt
from onnx2pytorch import ConvertModel
from tqdm import tqdm

  from .autonotebook import tqdm as notebook_tqdm


In [2]:
# --- Model: ONNX graph -> PyTorch module ---------------------------------------
onnx_model_path = "pruned_model.onnx"  # <-- Replace with your path
onnx_model = onnx.load(onnx_model_path)      # parsed ONNX graph
pytorch_model = ConvertModel(onnx_model)     # onnx2pytorch wrapper around that graph

# eval() switches layers to inference behaviour; it does not disable autograd, so
# input gradients for the saliency maps are still available. Kept as the cell's
# final expression so the converted module's layer summary is displayed.
pytorch_model.eval()

  layer.weight.data = torch.from_numpy(numpy_helper.to_array(weight))


ConvertModel(
  (Conv_/network_body/observation_encoder/processors.0/conv_layers/conv_layers.0/Conv_output_0): Conv2d(5, 16, kernel_size=(8, 8), stride=(4, 4))
  (LeakyRelu_/network_body/observation_encoder/processors.0/conv_layers/conv_layers.1/LeakyRelu_output_0): LeakyReLU(negative_slope=0.009999999776482582, inplace=True)
  (Conv_/network_body/observation_encoder/processors.0/conv_layers/conv_layers.2/Conv_output_0): Conv2d(16, 32, kernel_size=(4, 4), stride=(2, 2))
  (LeakyRelu_/network_body/observation_encoder/processors.0/conv_layers/conv_layers.3/LeakyRelu_output_0): LeakyReLU(negative_slope=0.009999999776482582, inplace=True)
  (Constant_/network_body/observation_encoder/processors.0/Constant_output_0): Constant(constant=tensor([ -1, 288]))
  (Reshape_/network_body/observation_encoder/processors.0/Reshape_output_0): Reshape(shape=None)
  (Gemm_/network_body/observation_encoder/processors.0/dense/dense.0/Gemm_output_0): Linear(in_features=288, out_features=256, bias=True)
  (Le

In [3]:
# === Load and prepare the recorded observations ===
# The .bin file is a flat float32 dump of consecutive observations, each one a
# 40 x 40 x 5 (H x W x C) grid, i.e. 8000 floats.
OBS_HEIGHT, OBS_WIDTH, OBS_CHANNELS = 40, 40, 5
OBS_SIZE = OBS_HEIGHT * OBS_WIDTH * OBS_CHANNELS  # 8000 floats per observation

bin_path = "foodAgentObs.bin"  # replace with your path
flat_obs = np.fromfile(bin_path, dtype=np.float32)

# Fail with a clear message instead of an opaque reshape error (truncated file) or,
# for an empty file, silently producing all-zero saliency maps further down.
obs_count, leftover = divmod(flat_obs.size, OBS_SIZE)
if obs_count == 0 or leftover:
    raise ValueError(
        f"{bin_path}: expected a non-empty multiple of {OBS_SIZE} float32 values "
        f"(one {OBS_HEIGHT}x{OBS_WIDTH}x{OBS_CHANNELS} observation each), "
        f"got {flat_obs.size}"
    )
obs_array = flat_obs.reshape((obs_count, OBS_HEIGHT, OBS_WIDTH, OBS_CHANNELS))  # NHWC (Unity-style)

# The converted model's first Conv2d takes the 5 channels as dim 1, so use NCHW.
obs_array = np.transpose(obs_array, (0, 3, 1, 2))  # -> (N, 5, 40, 40)

# All-ones mask fed to the model's "action_masks" input.
# NOTE(review): presumably this enables both options of a 2-way discrete branch — confirm.
action_masks = torch.tensor([[1.0, 1.0]], dtype=torch.float32)

# Running sums of |input gradient| (C, H, W) and occurrence counts per action bucket.
saliency_shoot = np.zeros((OBS_CHANNELS, OBS_HEIGHT, OBS_WIDTH))
saliency_noshoot = np.zeros((OBS_CHANNELS, OBS_HEIGHT, OBS_WIDTH))
saliency_forward = np.zeros((OBS_CHANNELS, OBS_HEIGHT, OBS_WIDTH))
saliency_strafe = np.zeros((OBS_CHANNELS, OBS_HEIGHT, OBS_WIDTH))
saliency_turn = np.zeros((OBS_CHANNELS, OBS_HEIGHT, OBS_WIDTH))
count_shoot = 0
count_noshoot = 0
count_forward = 0
count_strafe = 0
count_turn = 0

In [4]:
# === Accumulate per-action saliency maps ===
# For every observation, the gradient of the chosen discrete output and the gradient
# of the largest continuous output w.r.t. the input are computed *separately*.
# (Calling backward() on both and then reading obs_tensor.grad sums the two
# gradients, which would give every bucket the same combined map instead of the
# saliency of its own action head.)
for i in tqdm(range(obs_count)):
    obs_tensor = torch.tensor(obs_array[i], dtype=torch.float32).unsqueeze(0)  # (1, C, H, W)
    obs_tensor.requires_grad_(True)

    inputs = {"obs_0": obs_tensor, "action_masks": action_masks}
    output = pytorch_model(**inputs)

    # NOTE(review): assumes output[0] is the continuous-action head and output[1] the
    # discrete (shoot) head, each with a leading batch dim — confirm against the
    # ONNX graph's output order.
    continuous_logits = output[0]
    discrete_logits = output[1]
    # With batch size 1 the flattened argmax equals the column index used below.
    shoot_idx = discrete_logits.argmax().item()
    # NOTE(review): argmax on signed continuous outputs picks the most positive
    # component; if the dominant movement should ignore sign, use .abs().argmax().
    movement_idx = continuous_logits.argmax().item()

    # retain_graph=True keeps the graph alive for the second, independent gradient.
    (grad_shoot,) = torch.autograd.grad(
        discrete_logits[0, shoot_idx], obs_tensor, retain_graph=True
    )
    (grad_movement,) = torch.autograd.grad(continuous_logits[0, movement_idx], obs_tensor)
    saliency_discrete = grad_shoot.abs().squeeze(0).numpy()
    saliency_continuous = grad_movement.abs().squeeze(0).numpy()

    if shoot_idx == 0:
        saliency_noshoot += saliency_discrete
        count_noshoot += 1
    else:
        saliency_shoot += saliency_discrete
        count_shoot += 1

    # Continuous indices other than 0-2 are not bucketed.
    if movement_idx == 0:
        saliency_forward += saliency_continuous
        count_forward += 1
    elif movement_idx == 1:
        saliency_strafe += saliency_continuous
        count_strafe += 1
    elif movement_idx == 2:
        saliency_turn += saliency_continuous
        count_turn += 1




100%|██████████| 34950/34950 [01:43<00:00, 338.98it/s]


In [5]:
# === Mean saliency per action bucket ===
# The 1e-8 keeps a bucket that never occurred at an all-zero map instead of 0/0 = NaN.
avg_shoot, avg_noshoot, avg_forward, avg_strafe, avg_turn = (
    total / (n + 1e-8)
    for total, n in (
        (saliency_shoot, count_shoot),
        (saliency_noshoot, count_noshoot),
        (saliency_forward, count_forward),
        (saliency_strafe, count_strafe),
        (saliency_turn, count_turn),
    )
)


In [6]:
# === Save per-channel saliency comparisons ===
channels = ['Food','Agent','Wall','Bad Food','Frozen Agent']


def _save_saliency_column(panels, channel_name, out_path):
    """Save a single-column figure of saliency heatmaps, one panel per action.

    Parameters
    ----------
    panels : list of (str, 2-D array)
        (action label, mean saliency map) pairs, drawn top to bottom.
    channel_name : str
        Observation channel being shown; used in every panel title.
    out_path : str
        File the figure is written to.

    All panels share one colour scale (0 .. largest value across the panels) and a
    single colour bar, so intensities are comparable between actions. With
    per-image autoscaling, equally dark panels could stand for very different
    gradient magnitudes, which defeats a side-by-side comparison.
    """
    # `or None` falls back to autoscaling when every map is all zeros.
    vmax = max(float(saliency_map.max()) for _, saliency_map in panels) or None
    fig, axs = plt.subplots(len(panels), 1, figsize=(5, 3 * len(panels)), squeeze=False)
    column = axs[:, 0]
    for ax, (label, saliency_map) in zip(column, panels):
        image = ax.imshow(saliency_map, cmap='Reds', interpolation='nearest', vmin=0, vmax=vmax)
        ax.set_title(f"{label}: {channel_name} Detected")
        ax.axis('off')
    fig.colorbar(image, ax=list(column), shrink=0.8, label="mean |gradient|")
    fig.savefig(out_path, dpi=300, bbox_inches='tight', pad_inches=0.25)
    plt.close(fig)  # close this figure explicitly rather than pyplot's "current" one


for ch_idx, ch_name in enumerate(channels):
    # Discrete head: No Shoot (top), Shoot (bottom).
    _save_saliency_column(
        [("No Shoot", avg_noshoot[ch_idx]), ("Shoot", avg_shoot[ch_idx])],
        ch_name,
        f"Discrete_saliency_comparison_channel_{ch_idx}.png",
    )
    # Continuous head: Forward (top), Strafe (middle), Turn (bottom).
    _save_saliency_column(
        [
            ("Forward", avg_forward[ch_idx]),
            ("Strafe", avg_strafe[ch_idx]),
            ("Turn", avg_turn[ch_idx]),
        ],
        ch_name,
        f"Continuous_saliency_comparison_channel_{ch_idx}.png",
    )