In [2]:
from tabpfnwide.patches import compute_attention_heads, _compute
from tabpfn.model.attention.full_attention import MultiHeadAttention
setattr(MultiHeadAttention, "compute_attention_heads", compute_attention_heads)
setattr(MultiHeadAttention, "_compute", _compute)
import torch
import numpy as np
from sklearn.model_selection import train_test_split
from tabpfn.model.loading import load_model_criterion_config
device = torch.device("cuda" if torch.cuda.is_available() else "cpu")

### Load model

In [3]:

model_name = "TabPFN-Wide-8k" 
assert model_name in ["TabPFN-Wide-1.5k", "TabPFN-Wide-5k", "TabPFN-Wide-8k", "TabPFNv2"], f"Model name {model_name} not recognized."
checkpoint_path = f"./models/{model_name}_submission.pt"
model, _, _ = load_model_criterion_config(
    model_path=None,
    check_bar_distribution_criterion=False,
    cache_trainset_representation=False,
    which='classifier',
    version='v2',
    download=True,
)
if model_name != "TabPFNv2":
    model.features_per_group = 1
    checkpoint = torch.load(checkpoint_path, map_location=device, weights_only=False)
    model.load_state_dict(checkpoint)

### Generate dataset and add noise features

In [4]:
from sklearn.datasets import make_classification

X, y = make_classification(
    n_samples=100,
    n_features=3,
    n_informative=3,
    n_redundant=0,
    n_classes=2,
    random_state=42
)

# Widen X with noise
noise = np.random.normal(0, 1, (X.shape[0], 18))
X_new = np.hstack((noise, X))
# Permuute columns randomly
permutation = np.random.permutation(X_new.shape[1])
X_new = X_new[:, permutation]

X, y = torch.tensor(X_new, dtype=torch.float32), torch.tensor(y, dtype=torch.int8)

In [None]:
X_train, X_test, y_train, y_test = train_test_split(X, y, test_size=0.2, random_state=42)
X_train_tensor = X_train.unsqueeze(1).to(device)
X_test_tensor = X_test.unsqueeze(1).to(device)
y_train_tensor = torch.tensor(y_train, dtype=torch.int8).unsqueeze(1).to(device)
y_test_tensor = torch.tensor(y_test, dtype=torch.int8).unsqueeze(1).to(device)
# Activate attention map saving
for layer in model.transformer_encoder.layers:
    layer.self_attn_between_features.attention_map = None
    layer.self_attn_between_features.save_att_map = True
    layer.self_attn_between_features.number_of_samples = X_train_tensor.shape[0]

with torch.inference_mode():
    with torch.autocast(device_type="cuda", dtype=torch.float16):
        pred_logits = model(
            train_x=X_train_tensor,
            train_y=y_train_tensor,
            test_x=X_test_tensor,
        )
        n_classes = len(np.unique(y_train_tensor.cpu()))
        pred_logits = pred_logits[..., :n_classes].float()  
        pred_probs = torch.softmax(pred_logits, dim=-1)[:, 0, :].detach().cpu().numpy()

### Retrieve attention maps

In [6]:
atts = [getattr(layer.get_submodule("self_attn_between_features"), "attention_map") for layer in model.transformer_encoder.layers]
atts = torch.stack(atts, dim=0)
att_to_last_column = atts.mean(dim=0)[-1, :-1]

In [7]:
print("Indices of the 10 most important features (according to attention to the last column):")
print(np.argsort(att_to_last_column.cpu().numpy())[-10:])
original_feature_indices = np.where(np.isin(permutation, [18, 19, 20]))[0]
print("Indices of the original features in the permuted array:", original_feature_indices)

Indices of the 10 most important features (according to attention to the last column):
[ 1  0 10 14  2 11 17 20  9 18]
Indices of the original features in the permuted array: [ 9 18 20]


### The 3 most important features according to the attention score are the original features.