In [1]:
import torch
from pathlib import Path
import pickle
from minigrid.core.constants import DIR_TO_VEC
import numpy as np

# Folder (relative to the working directory) that holds the saved debug
# tensors (*.pt) and the pickled datasets read by this notebook.
directory = Path("data") / "model-debug"

def load_data(name: str) -> dict[str, list[torch.Tensor]]:
    """
    Load the saved per-layer tensors for one quantity from ``directory``.

    Files are expected to be named ``{name}_{split}_{layer}.pt`` where split 0
    is returned under ``"train"`` and split 1 under ``"validation"``; layers
    0, 1 and 2 are loaded for each split, in that order.

    Args:
        name (str): File-name prefix of the quantity to load (e.g. "attention").

    Returns:
        dict[str, list[torch.Tensor]]: ``{"train": [...], "validation": [...]}``,
        each list holding one CPU tensor per layer.
    """
    def _load(split: int, layer: int) -> torch.Tensor:
        # map_location="cpu" deserializes straight onto the CPU, so files that
        # were saved from a GPU can still be read on a machine where that
        # device is not available (a plain load followed by .cpu() would fail
        # during deserialization in that case).
        return torch.load(
            directory / f"{name}_{split}_{layer}.pt",
            weights_only=True,
            map_location="cpu",
        )

    layers = range(3)
    return {
        "train": [_load(0, layer) for layer in layers],
        # The validation tensors go through torch.cat so that further splits
        # can be appended along dim 0; split 2 (``{name}_2_{layer}.pt``) is
        # currently left out - add ``_load(2, layer)`` to the list to include it.
        "validation": [torch.cat([_load(1, layer)]) for layer in layers],
    }

# Per-layer tensors saved during the model debug run, keyed by dataset split.
attention_weights = load_data("attention")
eta = load_data("eta")

# Pickled (features, targets, predictions) for each split.
# SECURITY NOTE: unpickling can execute arbitrary code - only load files that
# were produced by a trusted run.
# NOTE(review): the annotation says list[torch.Tensor], but the entries are
# unpacked into several values below - confirm the real element type.
data: dict[str, list[torch.Tensor]] = {}
for split_name, file_name in (
    ("train", "data_TRAIN.pkl"),
    ("validation", "data_VALIDATION.pkl"),
):
    # Context manager so the file handle is closed right after loading.
    with open(directory / file_name, "rb") as pickle_file:
        data[split_name] = pickle.load(pickle_file)

# Split inspected by the rest of the notebook.
dataset = 'train'
feature_obs, feature_action = data[dataset][0]
target_obs, target_reward = data[dataset][1]
pred_obs, pred_reward, pred_eta = data[dataset][2]
aw = attention_weights[dataset]
e = eta[dataset]

n_samples = feature_obs.shape[0]
# Axis 1 of the attention tensors is treated as the head axis (it is the axis
# summed over later in the notebook).
n_heads = aw[0].shape[1]
# Index rows of every entry whose channel 3 (last axis) is non-zero; used
# below as the agent's location.
# NOTE(review): indexing this by sample number assumes exactly one such entry
# per sample - confirm.
feature_agent_pos = (feature_obs[..., 3] != 0).nonzero()

grid_width = feature_obs.shape[1]
grid_height = feature_obs.shape[2]

amdgpu.ids: No such file or directory


In [None]:
def print_full_matrix(matrix, threshold=1e-3, width=7, precision=3):
    """
    Prints a PyTorch matrix in full without scientific notation and with values close to zero shown as 0.

    The output is a header line with the column indices, a dashed separator,
    and then one line per row prefixed with the row index. Values whose
    absolute value is below ``threshold`` are printed as a centered "0";
    all other values are right-aligned in their cell.

    Args:
        matrix (torch.Tensor): The PyTorch tensor to print. 2D tensors are
            printed as-is; 1D tensors are printed as a single row and 0D
            (scalar) tensors as a 1x1 matrix.
        threshold (float): Values with absolute value below this threshold will be displayed as 0
        width (int): Width of each cell in characters
        precision (int): Number of decimal places for floating point values

    Raises:
        ValueError: If ``matrix`` has more than two dimensions.
    """
    # Normalise the input to exactly two dimensions.
    if matrix.dim() > 2:
        raise ValueError(f"Expected 2D tensor, got tensor with {matrix.dim()} dimensions")
    if matrix.dim() < 2:
        # reshape(1, -1) turns a 1D vector into 1xN and a 0D scalar into 1x1,
        # so the shape unpacking below cannot fail.
        matrix = matrix.reshape(1, -1)

    rows, cols = matrix.shape

    # Work on a numpy copy of the values for easier per-element formatting.
    matrix_np = matrix.detach().cpu().numpy()

    zero_cell = "0".center(width)

    def format_cell(value):
        # Values close to 0 are shown as a bare, centered 0.
        if abs(value) < threshold:
            return zero_cell
        # Fixed-point / integer formatting keeps scientific notation out.
        if isinstance(value, (int, np.integer)):
            return f"{value:{width}d}"
        return f"{value:{width}.{precision}f}"

    # Header with column indices (4 leading spaces match the "  i|" row prefix).
    print("    " + "".join(f"{j:>{width}}" for j in range(cols)))
    print("-" * (cols * width + 6))

    # One output line per row, prefixed with the row index.
    for i in range(rows):
        cells = "".join(format_cell(matrix_np[i, j]) for j in range(cols))
        print(f"{i:>3}|" + cells)

# Sample to inspect.
index = 53
print(f"Sample {index}")

# Row `index` of the nonzero() result; its first entry is dropped to get the
# in-grid position, the full row is used to look up the observation cell.
agent_entry = feature_agent_pos[index]
pos = agent_entry[1:]
# Channel 3 of the agent's cell, shifted by one, selects the direction vector.
agent_channel_value = feature_obs[tuple(agent_entry)][3]
pos_front = (pos + torch.Tensor(DIR_TO_VEC[agent_channel_value - 1])).long()

# Flattened cell indices are shown in parentheses next to each position.
pos_flat = pos[1] * grid_width + pos[0]
pos_front_flat = pos_front[1] * grid_width + pos_front[0]
print(f"Agent Position: {pos} ({pos_flat})")
print(f"Agent Front: {pos_front} ({pos_front_flat})")

# Start from a batch of identity matrices (one per sample) with two extra
# rows/columns on top of the grid cells.
n_tokens = grid_width * grid_height + 2
eta_test = torch.eye(n_tokens).expand(n_samples, -1, -1)  # not used in the code visible here
eta_new = torch.eye(n_tokens).expand(n_samples, -1, -1)

# Re-derive eta layer by layer from the attention weights and compare it with
# the eta tensors that were saved alongside them.
for layer_idx, layer_attention in enumerate(attention_weights[dataset]):
    print(f"Layer {layer_idx}")

    # Collapse axis 1 of the attention weights by summation.
    attention_summed = layer_attention.sum(dim=1)
    # attention_summed += torch.eye(grid_width * grid_height + 3, grid_width * grid_height + 3).expand(n_samples, -1, -1)

    print("Attention summed:")
    print_full_matrix(attention_summed[index])

    # eta accumulates: new = old + attention @ old.
    eta_layer = torch.bmm(attention_summed, eta_new)
    eta_new = eta_new + eta_layer

    print("eta:")
    matches_saved = torch.allclose(eta[dataset][layer_idx][index], eta_new[index])
    print(f"Correct: {matches_saved}")
    print_full_matrix(eta_new[index])

    # for head_idx in range(len(attention_weights[dataset][layer_idx][0])):
    #     pos_idx = pos[0] * grid_width + pos[1]
    #     pos_front_idx = pos_front[0] * grid_width + pos_front[1]

    #     print(f"Attention (head {head_idx}): {attention_weights[dataset][layer_idx][index, head_idx, pos[0], pos[1]]}")
    #     print(f"Attention (head {head_idx}) (front): {attention_weights[dataset][layer_idx][index, head_idx, pos_front[0], pos_front[1]]}")

    print()

Sample 53
Agent Position: tensor([2, 2]) (12)
Agent Front: tensor([2, 3]) (17)
Layer 0
Attention summed:
          0      1      2      3      4      5      6      7      8      9     10     11     12     13     14     15     16     17     18     19     20     21     22     23     24     25     26
---------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------
  0|  2.000   0      0      0      0      0      0      0      0      0      0      0      0      0      0      0      0      0      0      0     1.000   0      0      0      0     1.000   0   
  1|  2.000   0      0      0      0      0      0      0      0      0      0      0      0      0      0      0      0      0      0      0     1.000   0      0      0      0     1.000   0   
  2|  2.000   0      0      0      0      0      0      0      0      0      0      0      0      0      0      0    

: 