# **TabPFN Relational Benchmark**
This notebook benchmarks the performance of TabPFN models on datasets from RelBench in two scenarios:
1. **Single Table** – Using only the target entity table.
2. **Merged Table** – Using a naively denormalized table obtained by joining related tables.

It automates dataset loading, preprocessing (including date feature engineering), vectorization, model training, prediction, and evaluation for all compatible tasks within a chosen RelBench dataset. The results allow comparing model performance between single-table and merged-table configurations.


## Import Libraries

In [None]:
# --- Standard Library ---
import os
import time
import inspect
from typing import Any, Dict, List, Optional

# --- Third-Party Libraries ---
import numpy as np
import pandas as pd

# --- PyTorch and PyTorch Geometric ---
import torch
from torch import Tensor
from torch_frame.config.text_embedder import TextEmbedderConfig

# --- Skrub and Sentence Transformers ---
from skrub import TableVectorizer
from sentence_transformers import SentenceTransformer

# --- RelBench ---
from relbench.datasets import get_dataset
from relbench.tasks import get_task, get_task_names
from relbench.base import TaskType
import relbench.metrics
from relbench.modeling.utils import get_stype_proposal, to_unix_time
from relbench.modeling.graph import make_pkey_fkey_graph

# --- TabPFN ---
from tabpfn import TabPFNClassifier, TabPFNRegressor


### Set Global Configuration

In [None]:
# Device selection. CPU is the default because it works in every environment;
# define DEVICE (e.g. "cuda" or "mps") before running this cell to override it.
DEVICE = globals().get("DEVICE", "cpu")
print(f"Using device: {DEVICE}")

# Dataset to benchmark (any name accepted by relbench.datasets.get_dataset).
# "rel-stack" is used unless DATASET is already defined, e.g. by an earlier
# cell or a parameter-injection tool.
DATASET = globals().get("DATASET", "rel-stack")

# Global configuration variables with defaults (kept if already defined)
SEED = globals().get("SEED", 42)                  # seed for training-set subsampling
N_ESTIMATORS = globals().get("N_ESTIMATORS", 16)  # number of TabPFN estimators (not currently forwarded to the model)
TABPFN_MAX = globals().get("TABPFN_MAX", 2000)    # max training rows handed to TabPFN

## Notebook Configuration and Dataset Selection


Reads the dataset name (`DATASET`) and download flag (`DOWNLOAD`), falling back to defaults only if they are not already defined. It then discovers all available tasks for the selected dataset using RelBench’s APIs and keeps only the tasks compatible with TabPFN (classification and regression).


In [None]:
# Keep DATASET / DOWNLOAD from an earlier cell when they exist; the literals
# below are only fallbacks for running this cell on its own.
DATASET = globals()["DATASET"] if "DATASET" in globals() else "rel-f1"
DOWNLOAD = globals()["DOWNLOAD"] if "DOWNLOAD" in globals() else True

# Discover tasks and keep only entity-level cls/reg tasks TabPFN can handle
def _is_tabpfn_friendly(task):
    """Return True if ``task`` has a task type this notebook can fit with TabPFN.

    Binary and multiclass classification are fit with TabPFNClassifier, and
    regression with TabPFNRegressor. Multilabel classification is excluded:
    TabPFNClassifier is fit on one label per row, so multilabel tasks could
    only fail later in the benchmark loop.
    """
    return task.task_type in (
        TaskType.BINARY_CLASSIFICATION,
        TaskType.MULTICLASS_CLASSIFICATION,
        TaskType.REGRESSION,
    )

# Enumerate every task RelBench registers for the dataset and keep the ones
# TabPFN can handle. A task that fails to load (or to classify) is reported
# and skipped rather than aborting discovery.
_all = get_task_names(DATASET)
TASKS = []
for tname in _all:
    try:
        is_friendly = _is_tabpfn_friendly(get_task(DATASET, tname, download=DOWNLOAD))
    except Exception as e:
        print(f"[skip] {tname}: {e!s}")
        continue
    if is_friendly:
        TASKS.append(tname)

print(f"{DATASET}: {len(TASKS)} TabPFN-friendly tasks -> {TASKS}")


### Patch RelBench Metrics (Optional)

In [None]:
# Make RelBench's RMSE metric work on every scikit-learn version.
#
# scikit-learn deprecated the `squared=` keyword of mean_squared_error in 1.4
# (adding root_mean_squared_error) and later removed it.
# NOTE(review): this assumes relbench.metrics.rmse calls
# skm.mean_squared_error(..., squared=False), with skm being sklearn.metrics,
# and resolves it at call time. Verify against the installed RelBench version.
#
# Re-binding relbench.metrics.rmse alone only affects code that looks it up
# through the module; any metric list that captured the original function
# object keeps using it. The shim below is therefore installed on
# skm.mean_squared_error, so both paths get a version-safe implementation.
from sklearn.metrics import mean_squared_error

# On re-execution the attribute is already our shim: unwrap it so the shim is
# never stacked on top of itself.
_mse_impl = getattr(
    relbench.metrics.skm.mean_squared_error,
    "_relbench_compat_original",
    relbench.metrics.skm.mean_squared_error,
)
_MSE_HAS_SQUARED = "squared" in inspect.signature(_mse_impl).parameters
_RMSE_IMPL = getattr(relbench.metrics.skm, "root_mean_squared_error", None)


def _mean_squared_error_compat(y_true, y_pred, *, squared=True, **kwargs):
    """mean_squared_error that accepts ``squared=`` on every scikit-learn version.

    ``squared=True`` returns the MSE; ``squared=False`` returns the RMSE. The
    dedicated ``root_mean_squared_error`` is preferred when available, and the
    square root of the MSE is the last resort.
    """
    if squared:
        return _mse_impl(y_true, y_pred, **kwargs)
    if _RMSE_IMPL is not None:
        return _RMSE_IMPL(y_true, y_pred, **kwargs)
    if _MSE_HAS_SQUARED:
        return _mse_impl(y_true, y_pred, squared=False, **kwargs)
    return np.sqrt(_mse_impl(y_true, y_pred, **kwargs))


_mean_squared_error_compat._relbench_compat_original = _mse_impl
relbench.metrics.skm.mean_squared_error = _mean_squared_error_compat


def patched_rmse(true, pred):
    """Root-mean-squared error between ``true`` and ``pred`` on any scikit-learn version."""
    return _mean_squared_error_compat(true, pred, squared=False)


relbench.metrics.rmse = patched_rmse

### Text Embedding Configuration

In [None]:
class GloveTextEmbedding:
    """Callable text embedder backed by averaged 300-d GloVe word vectors.

    Wraps the sentence-transformers model
    "sentence-transformers/average_word_embeddings_glove.6B.300d" and returns
    the embeddings of a batch of strings as a torch tensor.
    """

    def __init__(self, device: Optional[torch.device] = None):
        glove_model_name = "sentence-transformers/average_word_embeddings_glove.6B.300d"
        self.model = SentenceTransformer(glove_model_name, device=device)

    def __call__(self, sentences: List[str]) -> Tensor:
        # encode() yields a NumPy array; torch_frame expects a tensor.
        encoded = self.model.encode(sentences)
        return torch.from_numpy(encoded)

## Helper Functions for Dataset Loading and Table Processing


### Date Feature Engineering

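Loads the dataset's database and processes every table. Each column whose name contains "date" (case-insensitive) is parsed as a datetime, and missing or unparseable values are filled with that column's earliest date. Engineered date features are then added (e.g., year, month, weekday, cyclical encodings).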


In [None]:
# Load the dataset and its database. DOWNLOAD is passed explicitly so this
# call follows the same download setting as the other get_dataset/get_task
# calls in the notebook.
dataset = get_dataset(DATASET, download=DOWNLOAD)
db = dataset.get_db()

def to_pandas(table):
    """Return ``table`` as a pandas DataFrame.

    An object exposing ``to_pandas()`` is converted with it; otherwise the
    object's ``df`` attribute is returned as-is (not copied).

    Raises:
        ValueError: if the object offers neither.
    """
    if hasattr(table, "to_pandas"):
        return table.to_pandas()
    try:
        return table.df
    except AttributeError:
        raise ValueError("Unknown table type") from None

# Convert all tables to pandas DataFrames
tables = {name: to_pandas(tbl) for name, tbl in db.table_dict.items()}


def _engineer_date_features(frame):
    """Return a de-duplicated copy of ``frame`` with derived date features.

    Every column whose name contains "date" (case-insensitive) is parsed with
    ``pd.to_datetime(..., errors="coerce", utc=True)``. Missing or unparseable
    entries are filled with that column's earliest parsed value. The original
    date columns are dropped, and the following are appended after the
    remaining columns:
      * the parsed dates, made timezone-naive, under the original names;
      * calendar features computed in UTC: year, month, day, weekday, quarter,
        month start/end and weekend flags, sin/cos encodings of month and
        weekday, and days elapsed since the column's earliest date.
    A frame without such columns is returned after de-duplication only.
    """
    frame = frame.loc[:, ~frame.columns.duplicated()].copy()
    date_cols = [col for col in frame.columns if "date" in col.lower()]
    if not date_cols:
        return frame

    parsed, derived = {}, {}
    for col in date_cols:
        stamps = pd.to_datetime(frame[col], errors="coerce", utc=True)
        stamps = stamps.fillna(stamps.min())
        parsed[col] = stamps.dt.tz_localize(None)
        month = stamps.dt.month
        weekday = stamps.dt.weekday
        derived[f"{col}_year"] = stamps.dt.year
        derived[f"{col}_month"] = month
        derived[f"{col}_day"] = stamps.dt.day
        derived[f"{col}_weekday"] = weekday
        derived[f"{col}_quarter"] = stamps.dt.quarter
        derived[f"{col}_is_month_start"] = stamps.dt.is_month_start.astype(int)
        derived[f"{col}_is_month_end"] = stamps.dt.is_month_end.astype(int)
        derived[f"{col}_is_weekend"] = (weekday >= 5).astype(int)
        derived[f"{col}_month_sin"] = np.sin(2 * np.pi * month / 12)
        derived[f"{col}_month_cos"] = np.cos(2 * np.pi * month / 12)
        derived[f"{col}_weekday_sin"] = np.sin(2 * np.pi * weekday / 7)
        derived[f"{col}_weekday_cos"] = np.cos(2 * np.pi * weekday / 7)
        derived[f"{col}_elapsed_days"] = (stamps - stamps.min()).dt.days

    # Build all new columns at once and append them in a single concat.
    remaining = frame.drop(columns=date_cols)
    extra = pd.DataFrame({**parsed, **derived}, index=remaining.index)
    return pd.concat([remaining, extra], axis=1)


# Date feature engineering: every entry of `tables` is replaced by a new,
# feature-augmented DataFrame (the input frames themselves are not mutated).
for name, df in tables.items():
    tables[name] = _engineer_date_features(df)

# Write the processed DataFrames back into the database tables
for name, df in tables.items():
    db.table_dict[name].df = df

# torch_frame's series type inference can raise
# "ValueError: The truth value of a Series is ambiguous" on some columns.
# Wrap it so such columns are retried on their non-null values only. Any other
# ValueError (or a failure of the retry) still propagates.
import torch_frame.utils.infer_stype as ts

# The marker attribute keeps re-runs of this cell from wrapping the wrapper.
if not getattr(ts.infer_series_stype, "__safe_wrapped__", False):
    original_infer = ts.infer_series_stype

    def safe_infer_series_stype(ser):
        try:
            return original_infer(ser)
        except ValueError as err:
            if "truth value of a Series" not in str(err):
                raise
            return original_infer(ser.dropna())

    safe_infer_series_stype.__safe_wrapped__ = True
    ts.infer_series_stype = safe_infer_series_stype

### Flatten One-Hop Foreign Key Relationships

Given the database and a target table, this function left-joins every table that the target table references through a foreign key (one hop). Neighbor rows are used only if their timestamp is at or before the corresponding row's cutoff time. The result is a single denormalized pandas DataFrame. The `hetero_data` argument is accepted but not used.

In [None]:
def flatten_one_hop(hetero_data, db, target_table, cutoff_times, base_df=None):
    """Denormalize ``target_table`` by left-joining the tables it references (one hop).

    Row ``i`` of the result is row ``i`` of the base frame, extended with the
    columns of every table referenced through one of ``target_table``'s foreign
    keys. A referenced row is attached only if its time column is at or before
    ``cutoff_times[i]``. Otherwise its columns are left missing, so no
    information from after the cutoff leaks into row ``i``.

    Args:
        hetero_data: Not used; kept so existing callers keep working.
        db: RelBench database; ``db.table_dict`` maps table names to tables.
        target_table: Name of the table whose foreign keys are followed.
        cutoff_times: One cutoff per output row, in the units returned by
            ``to_unix_time``. ``None`` disables time filtering for that row.
        base_df: Optional frame whose rows are expanded instead of the rows of
            ``target_table`` itself. It must contain ``target_table``'s
            foreign-key columns.

    Returns:
        A DataFrame with ``len(cutoff_times)`` rows and a fresh RangeIndex.
        Neighbor columns whose names clash with existing ones get the suffix
        ``_<referenced table>``. The referenced tables' primary-key columns are
        not kept because they duplicate the foreign key.

    Raises:
        IndexError: if there are more cutoffs than base rows.
    """
    target = db.table_dict[target_table]
    base = to_pandas(target) if base_df is None else base_df
    cutoffs = np.array(
        [np.nan if c is None else c for c in cutoff_times], dtype="float64"
    )
    # Row i is paired positionally with cutoff i.
    out = base.iloc[np.arange(len(cutoffs))].reset_index(drop=True)

    # Temporary helper columns; removed again after every join.
    join_key, nbr_time = "__flatten_join_key__", "__flatten_nbr_time__"
    for fkey, pkey_table in target.fkey_col_to_pkey_table.items():
        nbr_tbl = db.table_dict[pkey_table]
        raw = to_pandas(nbr_tbl)
        # Rename (rather than drop) the primary key so it can still serve as
        # the join key; it is dropped after the merge.
        nbr = raw.rename(columns={nbr_tbl.pkey_col: join_key})
        if nbr_tbl.time_col is not None:
            nbr[nbr_time] = to_unix_time(raw[nbr_tbl.time_col])

        existing = set(out.columns)
        out = out.merge(
            nbr,
            left_on=fkey,
            right_on=join_key,
            how="left",
            suffixes=("", f"_{pkey_table}"),
        )
        if nbr_time in out.columns:
            # NaN (no cutoff / no match) compares False, so such rows stay as joined.
            too_new = out[nbr_time].to_numpy(dtype="float64") > cutoffs
            for col in out.columns:
                if col not in existing and col not in (join_key, nbr_time):
                    out[col] = out[col].mask(too_new)
        out = out.drop(columns=[c for c in (join_key, nbr_time) if c in out.columns])
    return out

### Fetch Dataset Splits

Utility functions to load a task’s splits (`train`, `val`, `test`), convert them to pandas DataFrames, and extract features (`X`) and targets (`y`). Includes functions to build:
- **Single-table frames** directly from the target entity table.
- **Merged-table frames** by performing a one-hop foreign key → primary key join to denormalize data.


In [None]:
def fetch_splits(dataset_name: str, task_name: str, download: bool = True):
    """Load a RelBench task and its train/val/test tables.

    Tables are requested with ``mask_input_cols=False`` so every split keeps
    its original columns.

    Returns:
        ``(task, splits)`` where ``splits`` maps "train"/"val"/"test" to the
        table returned by ``task.get_table``.
    """
    task = get_task(dataset_name, task_name, download=download)
    splits = {}
    for split_name in ("train", "val", "test"):
        splits[split_name] = task.get_table(split_name, mask_input_cols=False)
    return task, splits

def to_Xy(df: pd.DataFrame, target_col: str):
    """Split ``df`` into features and target.

    Returns:
        ``(X, y)``: ``X`` is a new DataFrame without ``target_col`` (``df`` is
        left untouched) and ``y`` is the target column as a NumPy array.
    """
    label_values = df[target_col].to_numpy()
    return df.drop(columns=[target_col]), label_values

# Infers the primary key column from a table object and its DataFrame. It checks for common attribute names first, then falls back to finding the first unique column.
def _infer_pk(table_obj, df: pd.DataFrame):
    # best-effort: check common attribute names first, then infer by uniqueness
    for attr in ("primary_key_col", "pkey", "pk", "primary_key", "id_col"):
        if hasattr(table_obj, attr):
            cand = getattr(table_obj, attr)
            if isinstance(cand, str) and cand in df.columns:
                return cand
    # fallback: take the first unique column if exists
    for c in df.columns:
        try:
            if df[c].is_unique:
                return c
        except Exception:
            pass
    return None

def build_single_table_frames(task, splits):
    """Build ``(X, y, df)`` per split directly from the task's split tables.

    ``df`` is a copy of the split table's DataFrame, kept for evaluation
    alignment. ``X`` is that copy without ``task.target_col`` (column order
    untouched), and ``y`` is the target as a NumPy array.
    NOTE(review): features are only the split table's own columns; no entity
    table attributes are joined in single mode.
    """

    def _frame(table):
        df = table.df.copy()
        X, y = to_Xy(df, task.target_col)
        return X, y, df

    return {split: _frame(table) for split, table in splits.items()}

def _join_visible(left, db, table_name, left_key, cutoff, suffix):
    """Left-join table ``table_name`` onto ``left`` where ``left[left_key]`` equals its primary key.

    Joined columns are masked (set to missing) on rows where the joined row's
    time column is later than that row's ``cutoff``. Cutoffs are in the units
    of ``to_unix_time``; NaN means no cutoff. The joined table's primary-key
    column is not kept, and clashing column names receive ``suffix``. Row order
    and count of ``left`` are preserved, assuming the primary key is unique.
    """
    tbl = db.table_dict[table_name]
    raw = to_pandas(tbl)
    join_key, row_time = "__merge_join_key__", "__merge_row_time__"
    right = raw.rename(columns={tbl.pkey_col: join_key})
    if tbl.time_col is not None:
        right[row_time] = to_unix_time(raw[tbl.time_col])

    existing = set(left.columns)
    merged = left.merge(
        right, left_on=left_key, right_on=join_key, how="left", suffixes=("", suffix)
    )
    if row_time in merged.columns:
        # NaN (no cutoff / no match) compares False, so such rows stay as joined.
        too_new = merged[row_time].to_numpy(dtype="float64") > cutoff
        for col in merged.columns:
            if col not in existing and col not in (join_key, row_time):
                merged[col] = merged[col].mask(too_new)
    return merged.drop(columns=[c for c in (join_key, row_time) if c in merged.columns])


def build_merged_table_frames(hetero_data, db, task, splits):
    """Build ``(X, y, merged_df)`` per split from a one-hop denormalization.

    Each split's rows are first left-joined with their entity row
    (``task.entity_col`` -> entity table primary key). They are then joined
    with every table the entity table references through a foreign key.
    Joined rows whose time column is later than the split row's own timestamp
    (``task.time_col``) are masked out, so each row only sees information
    available at its cutoff. The split table's row order is preserved, so
    predictions stay aligned with the table handed to ``task.evaluate``.

    Args:
        hetero_data: Not used; kept for interface compatibility.
        db: RelBench database providing the entity and neighbor tables.
        task: RelBench entity task (``entity_table``, ``entity_col``,
            ``time_col``, ``target_col``).
        splits: Mapping split name -> split table (or DataFrame-like object).

    Returns:
        dict mapping split name to ``(X, y, merged_df)``.
    """
    entity = task.entity_table
    entity_tbl = db.table_dict[entity]
    time_col = getattr(task, "time_col", None)
    out = {}

    for name, table in splits.items():
        df = table.df if hasattr(table, "df") else to_pandas(table)
        df = df.reset_index(drop=True)
        # Cutoffs come from the split table's own time column, not the entity
        # table's time column (which the split table need not contain).
        if time_col is not None and time_col in df.columns:
            cutoff = to_unix_time(df[time_col]).astype("float64")
        else:
            cutoff = np.full(len(df), np.nan)

        # Attach the entity's own attributes, then its one-hop neighbors.
        merged = _join_visible(df, db, entity, task.entity_col, cutoff, f"_{entity}")
        for fkey, pkey_table in entity_tbl.fkey_col_to_pkey_table.items():
            # The entity's fkey column got the entity suffix if the split
            # table already had a column with that name.
            left_key = f"{fkey}_{entity}" if fkey in df.columns else fkey
            # Suffix by fkey so several fkeys pointing at the same table
            # (or back at the entity table) yield distinct column names.
            merged = _join_visible(merged, db, pkey_table, left_key, cutoff, f"_{fkey}")

        X, y = to_Xy(merged, task.target_col)
        out[name] = (X, y, merged)
    return out


## Vectorization Wrapper (Version-Safe)

Initializes a `TableVectorizer` with only the keyword arguments supported by the installed `skrub` version, ensuring compatibility. The vectorizer is fit on the `train` split only. The `train`, `val`, and `test` splits are then transformed into numerical feature matrices, converting sparse output to dense arrays where necessary.


### Helper Functions for Vectorization

In [None]:
def _make_table_vectorizer():
    """Create a TableVectorizer using only keyword arguments the installed version accepts.

    ``cardinality_threshold`` (default 1000) and ``high_cardinality_transformer``
    (default "hashing") are always passed when supported. They can be
    overridden via the CARDINALITY_THRESHOLD / HIGH_CARD_TRANSFORMER globals.
    ``text_separator``, ``numerical_transformer`` and ``categorical_transformer``
    are passed only when supported AND the matching global (TEXT_SEPARATOR,
    NUMERICAL_TRANSFORMER, CATEGORICAL_TRANSFORMER) is defined.
    NOTE(review): "hashing" is passed as a plain string; confirm the installed
    skrub version accepts that value for high_cardinality_transformer.
    """
    accepted = set(inspect.signature(TableVectorizer.__init__).parameters) - {"self"}
    settings = globals()

    # (constructor argument, notebook global, default)
    with_defaults = (
        ("cardinality_threshold", "CARDINALITY_THRESHOLD", 1000),
        ("high_cardinality_transformer", "HIGH_CARD_TRANSFORMER", "hashing"),
    )
    # (constructor argument, notebook global) -- only used when the global exists
    only_if_defined = (
        ("text_separator", "TEXT_SEPARATOR"),
        ("numerical_transformer", "NUMERICAL_TRANSFORMER"),
        ("categorical_transformer", "CATEGORICAL_TRANSFORMER"),
    )

    kwargs = {
        arg: settings.get(name, default)
        for arg, name, default in with_defaults
        if arg in accepted
    }
    kwargs.update(
        {arg: settings[name] for arg, name in only_if_defined if arg in accepted and name in settings}
    )
    return TableVectorizer(**kwargs)

# Converts a sparse matrix to a dense NumPy array, handling cases where the input is already dense or does not support `.toarray()`.
def _to_dense(X):
    try:
        # scipy sparse matrices have .toarray()
        return X.toarray() if hasattr(X, "toarray") else X
    except Exception:
        return X

def vectorize_splits(X_train, X_val, X_test):
    """Vectorize the three splits with one TableVectorizer.

    The vectorizer is fit on the training split only (no information from
    val/test leaks into the encoding), then applied to val and test. Outputs
    are densified where possible.

    Returns:
        ``(vectorizer, X_train_vec, X_val_vec, X_test_vec)``.
    """
    vectorizer = _make_table_vectorizer()
    train_mat = _to_dense(vectorizer.fit_transform(X_train))
    val_mat, test_mat = (_to_dense(vectorizer.transform(part)) for part in (X_val, X_test))
    return vectorizer, train_mat, val_mat, test_mat


In [None]:
def _subsample(X, y, cap=TABPFN_MAX, seed=SEED):
    """Randomly keep at most ``cap`` rows of ``(X, y)`` without replacement.

    Sampling uses ``np.random.RandomState(seed)``, so it is reproducible.
    ``X`` may be a DataFrame (indexed with ``.iloc``) or an array. Note that
    the defaults are bound when this function is defined.

    Returns:
        ``(X_sub, y_sub, idx)``. When ``len(X) <= cap``, the inputs are returned
        unchanged together with ``arange(len(X))``.
    """
    n_rows = len(X)
    if n_rows > cap:
        picked = np.random.RandomState(seed).choice(n_rows, size=cap, replace=False)
        X_sub = X.iloc[picked] if hasattr(X, "iloc") else X[picked]
        return X_sub, y[picked], picked
    return X, y, np.arange(n_rows)

def _fit_tabpfn(task, Xt, yt):
    """Fit and return a TabPFN model suited to ``task``.

    Regression tasks get a TabPFNRegressor; every other task type gets a
    TabPFNClassifier. Both run on DEVICE with ``ignore_pretraining_limits=True``.
    N_ESTIMATORS is currently not forwarded (that argument was commented out),
    so TabPFN's own default ensemble size is used.
    """
    use_regressor = task.task_type == TaskType.REGRESSION and TabPFNRegressor is not None
    estimator_cls = TabPFNRegressor if use_regressor else TabPFNClassifier
    model = estimator_cls(
        device=DEVICE,
        ignore_pretraining_limits=True,
    )
    model.fit(Xt, yt)
    return model

def _predict_for_task(task, model, X):
    """Return predictions in the form RelBench's evaluators are given.

    - regression: ``model.predict(X)``;
    - binary classification: column 1 of ``predict_proba`` (probability of the
      second class, assumed to be the positive label, e.g. for AUROC);
    - otherwise: the full ``predict_proba`` matrix.
    NOTE(review): column 1 does not exist if the model saw a single class
    during training.
    """
    if task.task_type == TaskType.REGRESSION:
        return model.predict(X)
    class_probs = model.predict_proba(X)
    is_binary = task.task_type == TaskType.BINARY_CLASSIFICATION
    return class_probs[:, 1] if is_binary else class_probs

### Build a hetero-temporal graph for experiments


In [None]:
# Ask RelBench for a semantic-type proposal for every column of every table.
col_to_stype_dict = get_stype_proposal(db)

# Text embedder configuration: GloVe sentence embeddings on DEVICE, batches of 256.
text_embedder_cfg = TextEmbedderConfig(
    batch_size=256,
    text_embedder=GloveTextEmbedding(device=torch.device(DEVICE)),
)

# Build the heterogeneous primary-key/foreign-key graph of the database.
# NOTE(review): in this notebook `data` is only forwarded as `hetero_data` to
# build_merged_table_frames, and col_stats_dict is not used afterwards.
data, col_stats_dict = make_pkey_fkey_graph(
    db,
    cache_dir=None,
    col_to_stype_dict=col_to_stype_dict,
    text_embedder_cfg=text_embedder_cfg,
)

### Run TabPFN on Selected Tasks

Runs TabPFN on a specified dataset and task, handling both single-table and merged-table modes. It vectorizes the data, fits the model, makes predictions, and evaluates performance using RelBench’s evaluators. Returns a dictionary with results.

In [None]:
def run_tabpfn_on_task(dataset_name: str, task_name: str, mode: str = "single") -> Dict[str, Any]:
    """Fit TabPFN on one RelBench task and evaluate it on the val and test splits.

    Args:
        dataset_name: RelBench dataset name.
        task_name: Task within that dataset.
        mode: "single" (split tables only) or "merged" (one-hop denormalized).

    Returns:
        dict with dataset/task/mode, the metric dicts returned by
        ``task.evaluate`` for val and test, and row counts. ``n_train_used`` is
        the training size after subsampling to TABPFN_MAX.

    Raises:
        ValueError: if ``mode`` is not "single" or "merged". This is checked
            before any data is loaded.
    """
    if mode not in ("single", "merged"):
        raise ValueError("mode must be 'single' or 'merged'")

    task, splits = fetch_splits(dataset_name, task_name, download=DOWNLOAD)

    # Build (X, y, frame) triples for every split in the requested mode
    if mode == "single":
        frames = build_single_table_frames(task, splits)
    else:
        # Only merged mode needs the database, so it is loaded only here.
        # NOTE(review): this is a freshly loaded database, not the global `db`
        # modified by the date-feature cell, so those engineered features do
        # not reach TabPFN. Confirm whether that is intended.
        input_db = get_dataset(dataset_name, download=DOWNLOAD).get_db()
        frames = build_merged_table_frames(data, input_db, task, splits)

    X_train_raw, y_train, _ = frames["train"]
    X_val_raw = frames["val"][0]
    X_test_raw = frames["test"][0]

    # Vectorize (the vectorizer is fit on the training split only)
    _, Xt, Xv, Xs = vectorize_splits(X_train_raw, X_val_raw, X_test_raw)

    # Respect TabPFN's sample cap
    Xt_cap, yt_cap, _ = _subsample(Xt, y_train, cap=TABPFN_MAX, seed=SEED)

    model = _fit_tabpfn(task, Xt_cap, yt_cap)

    # Predict; predictions are assumed to follow the row order of the split
    # tables that task.evaluate receives below.
    val_pred = _predict_for_task(task, model, Xv)
    test_pred = _predict_for_task(task, model, Xs)

    val_metrics = task.evaluate(val_pred, splits["val"])
    test_metrics = task.evaluate(test_pred, splits["test"])

    # Metric dicts are stored as returned by task.evaluate
    return {
        "dataset": dataset_name,
        "task": task_name,
        "mode": mode,
        "val_metrics": val_metrics,
        "test_metrics": test_metrics,
        "n_train_used": len(Xt_cap),
        "n_train_total": len(Xt),
        "n_val": len(Xv),
        "n_test": len(Xs),
    }

## Orchestrator for Benchmark Runs

Iterates over all discovered tasks and runs TabPFN in both **single** and **merged** modes. Collects performance metrics for validation and test splits into a results table, handling failures gracefully. Sorts results for easier comparison.


In [None]:
MODES = globals().get("MODES", ["single", "merged"])

records = []
failures = []

# Run TabPFN on every task in every mode. Each (split, metric) pair becomes
# one result row; a failing run is reported and recorded, not raised.
for task_name in TASKS:
    for mode in MODES:
        try:
            res = run_tabpfn_on_task(DATASET, task_name, mode=mode)
            for split in ("val", "test"):
                for metric_name, metric_value in (res.get(f"{split}_metrics") or {}).items():
                    # Skip metrics that are missing or NaN
                    if metric_value is None:
                        continue
                    if isinstance(metric_value, float) and np.isnan(metric_value):
                        continue
                    records.append(
                        {
                            "dataset": res.get("dataset", DATASET),
                            "task": res.get("task", task_name),
                            "split": split,
                            "mode": res.get("mode", mode),
                            "method": "TabPFN_experimental_v1.0",
                            "metric": metric_name,
                            "score": metric_value,
                        }
                    )
        except Exception as e:
            msg = f"[{DATASET} | {task_name} | {mode}] failed: {e!s}"
            print(msg)
            failures.append(msg)

# Collect results. Every record carries all columns, so a non-empty frame can
# always be sorted by task, mode and metric.
results_df = pd.DataFrame.from_records(records)
if results_df.empty:
    print("No successful runs were recorded. Check the failure messages above.")
    results_df = pd.DataFrame(
        columns=["dataset", "task", "split", "mode", "method", "metric", "score"]
    )
else:
    results_df = results_df.sort_values(["task", "mode", "metric"])

display(results_df)


### Save Results to CSV

In [None]:
# Specify output directory
out_dir = globals().get("OUT_DIR", "outputs")
os.makedirs(out_dir, exist_ok=True)

# Change timestamp format to "dd.mm.yyyy-hh:mm"
timestamp = time.strftime("%d.%m.%Y-%H:%M")
csv_name = f"tabpfn_{DATASET}_{timestamp}.csv"
csv_path = os.path.join(out_dir, csv_name)

# Round all numerical results to 4 decimal places before saving
if "score" in results_df.columns:
    if pd.api.types.is_numeric_dtype(results_df["score"]):
        results_df["score"] = results_df["score"].round(4)
    else:
        # Optionally, try to convert to numeric first
        results_df["score"] = pd.to_numeric(results_df["score"], errors="coerce").round(4)

# Filter out rows with empty score values before saving
if "score" in results_df.columns:
    results_df = results_df[results_df["score"].notnull()]

# Save the results DataFrame to a CSV file
results_df.to_csv(csv_path, index=False)
print(f"Saved results to: {csv_path}")
