# TabPFN Evaluation with SHAP Explanations

This notebook demonstrates the usage of the `tabpfn_evaluate` function for evaluating a dataset with TabPFN, outer cross-validation, and SHAP feature importances.

This notebook uses a public dataset to show a minimal working example. Custom datasets with results from the GEMSS feature selector should be processed using *tabpfn_evaluate_custom_dataset_results.ipynb*.


In [1]:
# %pip install -e ..

In [None]:
# Imports
from IPython.display import display, Markdown
import os
import json
import pandas as pd
import numpy as np
import pyarrow.parquet as pq
import plotly.express as px
from plotly import io as pio
import ast
from sklearn.utils import shuffle

from gemss.diagnostics.tabpfn_evaluation import tabpfn_evaluate

# Example data
from sklearn.datasets import load_iris, fetch_california_housing

pio.renderers.default = "notebook_connected"  # Ensures plotly plots show in notebooks

In [None]:
# Choose task: 'classification' or 'regression'
task = "regression"

# Load example datasets
if task == "classification":
    data = load_iris()
    X, y = data.data, data.target
    feature_names = data.feature_names
else:
    data = fetch_california_housing()
    X, y = data.data, data.target
    feature_names = data.feature_names

# Shuffle for better cross-validation splits
X, y = shuffle(X, y, random_state=42)

# Subsample heavily for faster demo
n_samples_allowed = 100
if X.shape[0] > n_samples_allowed:
    X = X[:n_samples_allowed]
    y = y[:n_samples_allowed]
n_features_allowed = 5
if X.shape[1] > n_features_allowed:
    X = X[:, :n_features_allowed]
    feature_names = feature_names[:n_features_allowed]

X_df = pd.DataFrame(X, columns=feature_names)
X_df.info()

<class 'pandas.core.frame.DataFrame'>
RangeIndex: 100 entries, 0 to 99
Data columns (total 5 columns):
 #   Column      Non-Null Count  Dtype  
---  ------      --------------  -----  
 0   MedInc      100 non-null    float64
 1   HouseAge    100 non-null    float64
 2   AveRooms    100 non-null    float64
 3   AveBedrms   100 non-null    float64
 4   Population  100 non-null    float64
dtypes: float64(5)
memory usage: 4.0 KB


## Run TabPFN Evaluation with SHAP

- Outer cross-validation (parameter `outer_cv_folds`)
- Feature scaling (parameter `apply_scaling`: standard, minmax, or None)
- SHAP explanations (parameter `explain`)
- Metrics printed for each fold

> For large X, SHAP explanations take time. For a quick demo, use a small subset or reduce folds.
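
The outer cross-validation and scaling steps listed above can be sketched with scikit-learn. This is only an illustration of the general pattern, not the internals of `tabpfn_evaluate`: the synthetic data, the `Ridge` stand-in model, and all variable names are assumptions chosen so the block runs without TabPFN installed.

```python
import numpy as np
from sklearn.datasets import make_regression
from sklearn.linear_model import Ridge
from sklearn.metrics import r2_score
from sklearn.model_selection import KFold
from sklearn.preprocessing import StandardScaler

# Synthetic stand-in data (100 samples, 5 features), not the housing subset above.
X_demo, y_demo = make_regression(
    n_samples=100, n_features=5, noise=10.0, random_state=42
)

outer_cv = KFold(n_splits=2, shuffle=True, random_state=42)
fold_r2 = []
for fold, (train_idx, test_idx) in enumerate(outer_cv.split(X_demo), start=1):
    # Fit the scaler on the training fold only, so no test-fold statistics leak in.
    scaler = StandardScaler().fit(X_demo[train_idx])
    X_train = scaler.transform(X_demo[train_idx])
    X_test = scaler.transform(X_demo[test_idx])

    # Ridge is a placeholder for the TabPFN model (assumption for this sketch).
    model = Ridge(alpha=1.0).fit(X_train, y_demo[train_idx])
    fold_r2.append(r2_score(y_demo[test_idx], model.predict(X_test)))
    print(f"Fold {fold}/{outer_cv.get_n_splits()}: r2_score = {fold_r2[-1]:.3f}")

print(f"Mean r2_score: {np.mean(fold_r2):.3f}")
```

Fitting the scaler inside the loop is the design choice that matters here: scaling the full dataset before splitting would let test-fold statistics influence training.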

In [10]:
results = tabpfn_evaluate(
    X_df,
    y,
    apply_scaling="standard",
    outer_cv_folds=2,
    tabpfn_kwargs=None,
    random_state=42,
    verbose=True,
    explain=True,
)

Fold 1/2
n_samples      50.000
n_features      5.000
r2_score        0.599
adjusted_r2     0.553
MSE             0.361
RMSE            0.601
MAE             0.476
MAPE           30.779
dtype: float64




Fold 2/2
n_samples      50.000
n_features      5.000
r2_score        0.454
adjusted_r2     0.392
MSE             0.714
RMSE            0.845
MAE             0.628
MAPE           33.804
dtype: float64




## CV Results: Average Metrics

In [11]:
pd.Series(results["average_scores"])

n_samples      50.0000
n_features      5.0000
r2_score        0.5265
adjusted_r2     0.4725
MSE             0.5375
RMSE            0.7230
MAE             0.5520
MAPE           32.2915
dtype: float64
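
The printed `adjusted_r2` and `RMSE` values are consistent with the standard definitions, adjusted R² = 1 - (1 - R²)(n - 1)/(n - p - 1) and RMSE = sqrt(MSE). The sketch below re-derives them from the per-fold numbers shown above. That `tabpfn_evaluate` computes them exactly this way is an inference from the output, not something confirmed by its source; the helper name `adjusted_r2` is hypothetical.

```python
import math


def adjusted_r2(r2, n_samples, n_features):
    """Standard adjusted R^2 for n_samples observations and n_features predictors."""
    return 1.0 - (1.0 - r2) * (n_samples - 1) / (n_samples - n_features - 1)


# Per-fold values copied from the printed outputs above.
fold_scores = [
    {"r2_score": 0.599, "MSE": 0.361},
    {"r2_score": 0.454, "MSE": 0.714},
]

adj_values = [adjusted_r2(s["r2_score"], 50, 5) for s in fold_scores]
rmse_values = [math.sqrt(s["MSE"]) for s in fold_scores]

for fold, (adj, rmse) in enumerate(zip(adj_values, rmse_values), start=1):
    print(f"Fold {fold}: adjusted_r2 = {adj:.3f}, RMSE = {rmse:.3f}")

# Averaging per-fold RMSE differs from taking the root of the averaged MSE.
mean_of_rmse = sum(rmse_values) / len(rmse_values)
rmse_of_mean_mse = math.sqrt(sum(s["MSE"] for s in fold_scores) / len(fold_scores))
print(f"mean of fold RMSEs = {mean_of_rmse:.3f}, sqrt of mean MSE = {rmse_of_mean_mse:.3f}")
```

The averaged `RMSE` of 0.7230 reported above matches the mean of the per-fold RMSEs rather than the square root of the averaged MSE, which is worth keeping in mind when comparing against results computed the other way.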

## Feature Importances (SHAP, Mean Per Fold)
Each series below shows the mean absolute SHAP value of every feature in one CV fold, sorted from most to least important.

In [12]:
for fold, shap_imp in enumerate(results.get("shap_explanations_per_fold", [])):
    display(Markdown(f"Fold {fold+1} SHAP Feature Importances:"))
    display(pd.Series(shap_imp).sort_values(ascending=False))
    print()

Fold 1 SHAP Feature Importances:

MedInc        0.965797
AveRooms      0.586431
HouseAge      0.327719
AveBedrms     0.174135
Population    0.032506
dtype: float64




Fold 2 SHAP Feature Importances:

MedInc        0.545384
HouseAge      0.235936
AveRooms      0.096574
AveBedrms     0.058998
Population    0.013096
dtype: float64
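
Because the two folds disagree on the order of `AveRooms` and `HouseAge`, it can help to look at the importances side by side. The following is a minimal sketch that aggregates the per-fold values printed above with pandas. The variable names are illustrative, and the values are hard-coded from the outputs rather than read from `results`.

```python
import pandas as pd

# Per-fold mean |SHAP| values copied from the outputs above.
shap_per_fold = [
    {"MedInc": 0.965797, "AveRooms": 0.586431, "HouseAge": 0.327719,
     "AveBedrms": 0.174135, "Population": 0.032506},
    {"MedInc": 0.545384, "HouseAge": 0.235936, "AveRooms": 0.096574,
     "AveBedrms": 0.058998, "Population": 0.013096},
]

# Rows are features, columns are folds.
folds = pd.DataFrame(shap_per_fold, index=["fold_1", "fold_2"]).T

# Mean importance across folds, most important feature first.
summary = folds.assign(mean=folds.mean(axis=1)).sort_values("mean", ascending=False)
print(summary)

# Share of each fold's total importance, so folds with different
# overall SHAP magnitudes can be compared on the same scale.
relative = folds / folds.sum(axis=0)
print(relative.round(3))
```

With only two folds of 50 samples each, differences in ranking between folds are expected; the relative shares make it easier to see that `MedInc` dominates in both.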


