In [1]:
from src import Trainable
from src.sverl import Shapley
import torch
import os
from src.utils import get_device

In [2]:
# Load the trained Shapley explainer. The path is built with os.path.join so the
# notebook runs on both Windows and POSIX systems.
model: Shapley = Trainable.load_checkpoint(os.path.join("checkpoints", "breakout-ppo-ps-1", "100.pt"))

  logger.warn(
  logger.warn(


None


In [None]:
# Measures validation loss of the Shapley model.
# For each sampled state x and a random coalition S (mask), the model's summed
# per-feature contributions over S should match v(S) - v(empty set), where v is
# the characteristic function.
N = 5000
SEED = 0
torch.manual_seed(SEED)  # make the random coalitions reproducible

total_loss = 0.0
n_seen = 0
# Pure evaluation: no autograd graph is needed for any of the three terms.
with torch.no_grad():
    for x, _ in model.state_sampler.sample(N):
        # Random coalition: each feature is kept with probability 0.5.
        mask = torch.rand(x.shape) < 0.5
        # NOTE(review): the return value of masker.mask is discarded, so this
        # assumes it applies the mask to x in place — confirm in src.sverl.
        model.characteristic.masker.mask(x, mask)
        masked = model.characteristic.masker.masked_like(x)

        # Characteristic-function values for the coalition and the fully masked input.
        part_1 = model.characteristic.infer(x)
        part_2 = model.characteristic.infer(masked)

        # Sum the model's contributions over the features inside the coalition.
        results = model.model(x)
        masked_results = mask.unsqueeze(-1).to(model.device) * results
        # Every axis except the leading batch axis and the trailing action axis.
        dim = tuple(range(1, masked_results.dim() - 1))
        part_3 = masked_results.sum(dim=dim)

        # Squared error summed over batch and actions.
        total_loss += torch.square(part_1 - part_2 - part_3).sum().item()
        # Count what was actually processed rather than trusting N.
        n_seen += x.shape[0]

# Average over the samples actually seen (guards against an empty sampler).
total_loss /= max(n_seen, 1)
print(total_loss)

0.2382502492427826


In [3]:
# Load the saved test points (one torch-serialized file per test point)
import numpy as np
def load_test_data(folder_path):
    """
    Load every file in ``folder_path`` with ``torch.load`` and return the results.

    Files are visited in sorted filename order so the returned list (and the
    order in which test points are later shown) is deterministic; the order of
    ``os.listdir`` itself is arbitrary. Files that fail to load are reported and
    skipped rather than aborting the whole load.

    Args:
        folder_path: Directory containing one serialized object per file.

    Returns:
        List of the loaded objects, in sorted filename order.
    """
    data = []
    for filename in sorted(os.listdir(folder_path)):
        full_path = os.path.join(folder_path, filename)
        try:
            # SECURITY: weights_only=False unpickles arbitrary objects and can
            # execute code — only point this at trusted, locally produced files.
            data.append(torch.load(full_path, weights_only=False))
            print(f"Loaded: {filename}")
        except Exception as e:
            # Best-effort: a bad file is reported but does not stop the others.
            print(f"Failed to load {filename}: {e}")
    return data

# Directory holding the saved breakout test points.
folder: str = "./data/test/breakout"
test_data: list = load_test_data(folder)

Loaded: 1744679029
Loaded: 1744679030
Loaded: 1744679036
Loaded: 1744679056
Loaded: 1744679057
Loaded: 1744679058
Loaded: 1744679061
Loaded: 1744679064
Loaded: 1744679065
Loaded: 1744679071
Loaded: 1744679072
Loaded: 1744679074
Loaded: 1744679096
Loaded: 1744679100
Loaded: 1744679109
Loaded: 1744679115
Loaded: 1744679121
Loaded: 1744679122
Loaded: 1744679127
Loaded: 1744679130
Loaded: 1744679143
Loaded: 1744679148


In [8]:
import pygame
import cv2
def symmetric_percentile_clip_and_normalize(shapley_vals, p=90):
    """
    Clip and normalize Shapley values symmetrically about zero.

    The clipping threshold is the ``p``-th percentile of ``|shapley_vals|``, so
    the same bound is applied to positive and negative values.

    Args:
        shapley_vals: np.ndarray of any shape (e.g. (H, W, A)) of Shapley values.
        p: Percentile in [0, 100] of the absolute values used as the threshold.

    Returns:
        Array of the same shape with values in [-1, 1]. If the threshold is 0
        (e.g. all, or at least p percent, of the values are zero) an all-zero
        float array is returned instead of dividing by zero (which gave NaN).
    """
    shapley_vals = np.asarray(shapley_vals)
    # Same threshold for both signs.
    threshold = np.percentile(np.abs(shapley_vals), p)
    if threshold == 0:
        return np.zeros(shapley_vals.shape, dtype=float)

    clipped = np.clip(shapley_vals, -threshold, threshold)
    # Normalize to [-1, 1].
    return clipped / threshold

def shapley_to_rgba_overlay(shapley_values: torch.Tensor) -> list:
    """
    Convert per-feature Shapley values into one RGBA heatmap per action.

    Args:
        shapley_values: tensor shaped (frames, H, W, actions), e.g. (4, 84, 84, 4)
            for the stacked-frame observations used in this notebook. Axis 2 is
            dropped first if it has size 1.

    Returns:
        List with one (H, W, 4) uint8 RGBA array per action. Values are averaged
        over frames and divided by their standard deviation. Positive values are
        drawn red and negative values blue, with alpha proportional to |value|;
        intensity saturates at one standard deviation. No resizing is done here.
    """
    shapley_values = shapley_values.squeeze(2)
    shapley_per_action = shapley_values.permute(3, 0, 1, 2)  # (action, frames, H, W)
    overlays = []

    for per_action in shapley_per_action:
        values = per_action.mean(dim=0).detach().cpu().numpy().astype(np.float64)  # (H, W)

        std = values.std()
        if std > 0:
            values = values / std
        else:
            # A constant map has nothing to highlight (and would give NaN/inf).
            values = np.zeros_like(values)

        # Clip BEFORE scaling: casting floats outside [0, 255] to uint8 wraps
        # around, which turned strong contributions into random intensities.
        values = np.clip(values, -1.0, 1.0)

        rgba = np.zeros(values.shape + (4,), dtype=np.uint8)
        rgba[..., 0] = (255 * np.clip(values, 0.0, None)).astype(np.uint8)   # Red: positive
        rgba[..., 2] = (255 * np.clip(-values, 0.0, None)).astype(np.uint8)  # Blue: negative
        rgba[..., 3] = (255 * np.abs(values)).astype(np.uint8)               # Alpha: magnitude
        overlays.append(rgba)

    return overlays

In [5]:
def blend_overlay(overlay_rgba: np.ndarray, background_rgb: np.ndarray) -> np.ndarray:
    """
    Alpha-composite an RGBA overlay on top of an RGB background.

    Both inputs are rescaled from [0, 255] to float32 in [0, 1], combined as
    ``colour * alpha + background * (1 - alpha)``, then mapped back to uint8.

    Args:
        overlay_rgba (np.ndarray): (H, W, 4) RGBA overlay image.
        background_rgb (np.ndarray): (H, W, 3) RGB background image.

    Returns:
        np.ndarray: (H, W, 3) uint8 composited image.
    """
    fg = overlay_rgba.astype(np.float32) / 255.0
    bg = background_rgb.astype(np.float32) / 255.0

    colour, opacity = fg[..., :3], fg[..., 3:]
    composite = colour * opacity + bg * (1.0 - opacity)

    return np.clip(composite * 255, 0, 255).astype(np.uint8)

In [12]:
device = get_device()

# Display constants: the 84x84 observation is upscaled 5x for viewing.
WIN_WIDTH = 84 * 5
WIN_HEIGHT = 84 * 5
FRAMERATE = 50
# True: alpha-blend each heatmap onto the game render. False: show the raw heatmap colours.
BLEND_WITH_RENDER = False


def _efficient_shapley(state: torch.Tensor):
    """Return (Shapley estimates with the batch axis squeezed, policy probabilities) for ``state``.

    The model's raw per-feature contributions are shifted uniformly so that they
    sum to policy_output - cf_output (the efficiency property).
    """
    with torch.no_grad():
        policy_output = torch.softmax(model.target(state), dim=-1)
        masked = model.characteristic.masker.masked_like(state)
        cf_output = model.characteristic.infer(masked)
        contributions = model.model(state)

    # Sum over every feature axis of the CONTRIBUTIONS: all axes except the
    # leading batch axis and the trailing action axis. This matches the
    # validation cell. Deriving the axes from the state's rank left one feature
    # axis unsummed.
    feature_dims = tuple(range(1, contributions.dim() - 1))
    contribution_sum = contributions.sum(dim=feature_dims)
    n_features = contributions.shape[1:-1].numel()  # e.g. 4 * 84 * 84

    correction = (policy_output - cf_output - contribution_sum) / n_features
    # Re-insert singleton feature axes so the (batch, action) correction
    # broadcasts per sample rather than against the trailing feature axes.
    for _ in feature_dims:
        correction = correction.unsqueeze(1)
    shapley = (contributions + correction).squeeze(0)
    return shapley, policy_output


def _overlay_surfaces(overlays, render_img):
    """Build one window-sized pygame surface per action heatmap."""
    surfaces = []
    for overlay in overlays:
        if BLEND_WITH_RENDER:
            # cv2 dsize is (width, height) = (cols, rows); match the render's array shape.
            heat = cv2.resize(np.swapaxes(overlay, 0, 1), (render_img.shape[1], render_img.shape[0]))
            frame = blend_overlay(heat, render_img)
        else:
            frame = np.swapaxes(overlay[..., :3], 0, 1)
        # interpolation must be passed by keyword: the third positional
        # argument of cv2.resize is `dst`, not the interpolation flag.
        scaled = cv2.resize(frame, (WIN_HEIGHT, WIN_WIDTH), interpolation=cv2.INTER_NEAREST)
        surfaces.append(pygame.surfarray.make_surface(scaled))
    return surfaces


def _show_surfaces(surfaces) -> bool:
    """Show each surface in turn, advancing on Enter; return False if the window was closed."""
    index = 0
    while index < len(surfaces):
        clock.tick(FRAMERATE)
        for event in pygame.event.get():
            if event.type == pygame.QUIT:
                return False
            if event.type == pygame.KEYDOWN and event.key == pygame.K_RETURN:
                index += 1
        if index >= len(surfaces):
            break
        win.blit(surfaces[index], (0, 0))
        pygame.display.update()
    return True


# Pygame setup
win = pygame.display.set_mode((WIN_WIDTH, WIN_HEIGHT))
pygame.display.set_caption("Agent Testing")
clock = pygame.time.Clock()

try:
    for test_point in test_data:
        render_img = test_point['render'].astype(np.uint8)
        state = torch.Tensor(test_point['state']).to(device)

        shapley, policy_output = _efficient_shapley(state)
        print(policy_output.argmax())

        surfaces = _overlay_surfaces(shapley_to_rgba_overlay(shapley), render_img)
        if not _show_surfaces(surfaces):
            break  # window closed by the user: stop instead of drawing to a dead display
finally:
    # Always release the window, even if a cell error occurs mid-loop.
    pygame.quit()


tensor(0, device='cuda:0')
tensor(3, device='cuda:0')
tensor(3, device='cuda:0')
tensor(1, device='cuda:0')
tensor(2, device='cuda:0')
tensor(1, device='cuda:0')
tensor(0, device='cuda:0')
tensor(3, device='cuda:0')
tensor(3, device='cuda:0')
tensor(2, device='cuda:0')
tensor(0, device='cuda:0')
tensor(3, device='cuda:0')
tensor(2, device='cuda:0')
tensor(0, device='cuda:0')
tensor(1, device='cuda:0')
tensor(1, device='cuda:0')
tensor(1, device='cuda:0')
tensor(1, device='cuda:0')
tensor(1, device='cuda:0')
tensor(1, device='cuda:0')
tensor(3, device='cuda:0')
tensor(0, device='cuda:0')
