# **TabPFN Relational Benchmark**
This notebook benchmarks the performance of TabPFN models on datasets from RelBench in two scenarios:
1. **Single Table** – Using only the target entity table.
2. **Merged Table** – Using a naively denormalized table obtained by joining related tables.

It automates dataset loading, preprocessing (including date feature engineering), vectorization, model training, prediction, and evaluation for all compatible tasks within a chosen RelBench dataset. The results allow comparing model performance between single-table and merged-table configurations.


## Import Libraries

In [None]:
# --- Standard Library ---
import os
import time
import inspect
from contextlib import contextmanager
from typing import Any, Dict, List, Optional

# --- Third-Party Libraries ---
import numpy as np
import pandas as pd

from sklearn.metrics import (
    roc_auc_score,
    accuracy_score,
    f1_score,
    mean_absolute_error,
    mean_squared_error,
)

# --- PyTorch and PyTorch Geometric ---
import torch
from torch import Tensor
from torch.nn import Embedding, ModuleDict
from torch_geometric.loader import NeighborLoader
from torch_geometric.data import HeteroData
from torch_geometric.nn import MLP
from torch_geometric.typing import NodeType

# --- Torch Frame ---
from torch_frame.config.text_embedder import TextEmbedderConfig
from torch_frame.data.stats import StatType

# --- Skrub and Sentence Transformers ---
from skrub import TableVectorizer
from sentence_transformers import SentenceTransformer

# --- RelBench ---
from relbench.datasets import get_dataset
from relbench.tasks import get_task, get_task_names
from relbench.base import TaskType
import relbench.metrics
from relbench.modeling.utils import get_stype_proposal
from relbench.modeling.graph import make_pkey_fkey_graph
from relbench.modeling.nn import HeteroEncoder, HeteroGraphSAGE, HeteroTemporalEncoder

# --- TabPFN ---
from tabpfn import TabPFNClassifier, TabPFNRegressor


### Set Global Configuration

In [None]:
# Compute backend, in order of preference: Apple MPS, then CUDA, then CPU.
DEVICE = (
    "mps" if torch.backends.mps.is_available()
    else "cuda" if torch.cuda.is_available()
    else "cpu"
)
print(f"Using device: {DEVICE}")

# Dataset to benchmark (any available dataset from relbench.datasets).
# NOTE: this assignment is unconditional, so it replaces whatever DATASET
# value existed before this cell ran.
DATASET = "rel-avito"

# Tunables: a value already present in the notebook namespace wins, otherwise
# the default written here is used.
SEED = globals().get("SEED", 42)
N_ESTIMATORS = globals().get("N_ESTIMATORS", 16)  # number of TabPFN estimators
TABPFN_MAX = globals().get("TABPFN_MAX", 10000)  # hard ceiling for TabPFN

## Notebook Configuration and Dataset Selection


Sets the dataset name (`DATASET`) and download flag (`DOWNLOAD`), then discovers all available tasks for the selected dataset using RelBench’s APIs. Filters tasks to only those compatible with TabPFN (classification and regression).


In [None]:
# Reuse existing config if present, otherwise set defaults.
# DATASET: keeps the value already defined in the notebook namespace and only
# falls back to "rel-f1" when no DATASET exists yet.
DATASET = globals().get("DATASET", "rel-f1")
# DOWNLOAD: forwarded as the `download` argument of the RelBench loaders
# (get_task / get_dataset) used in this notebook.
DOWNLOAD = globals().get("DOWNLOAD", True)

# Discover tasks and keep only entity-level cls/reg tasks TabPFN can handle
def _is_tabpfn_friendly(task):
    """Return True when ``task.task_type`` is a classification or regression type."""
    task_type = task.task_type
    supported_types = (
        TaskType.BINARY_CLASSIFICATION,
        TaskType.MULTICLASS_CLASSIFICATION,
        TaskType.MULTILABEL_CLASSIFICATION,
        TaskType.REGRESSION,
    )
    return task_type in supported_types

# Every task name RelBench reports for DATASET.
_all = get_task_names(DATASET)  # shown in tutorials
# Names of the tasks accepted by _is_tabpfn_friendly, in discovery order.
TASKS = []
for tname in _all:
    try:
        t = get_task(DATASET, tname, download=DOWNLOAD)
        if _is_tabpfn_friendly(t):
            TASKS.append(tname)
    except Exception as e:
        # Best effort: a task that cannot be loaded or inspected is reported
        # and skipped so the remaining tasks are still discovered.
        print(f"[skip] {tname}: {e!s}")

print(f"{DATASET}: {len(TASKS)} TabPFN-friendly tasks -> {TASKS}")


### Patch RelBench Metrics (Optional)

In [None]:
# Point relbench.metrics.skm.mean_squared_error at the mean_squared_error
# imported at the top of this notebook.
relbench.metrics.skm.mean_squared_error = mean_squared_error


def patched_rmse(true, pred):
    """Return the root mean squared error of ``pred`` against ``true``.

    Uses the ``squared=False`` keyword when the imported
    ``mean_squared_error`` still declares a ``squared`` parameter; otherwise
    takes the square root of the plain MSE.
    """
    accepts_squared = "squared" in inspect.signature(mean_squared_error).parameters
    if not accepts_squared:
        return np.sqrt(mean_squared_error(true, pred))
    return mean_squared_error(true, pred, squared=False)


# Expose the version-tolerant implementation under relbench.metrics.rmse.
relbench.metrics.rmse = patched_rmse

### Text Embedding Configuration

In [None]:
class GloveTextEmbedding:
    """Callable that embeds sentences with a GloVe sentence-transformer model."""

    def __init__(self, device: Optional[torch.device] = None):
        """Load the averaged-GloVe sentence-transformer on ``device``."""
        model_name = "sentence-transformers/average_word_embeddings_glove.6B.300d"
        self.model = SentenceTransformer(model_name, device=device)

    def __call__(self, sentences: List[str]) -> Tensor:
        """Encode ``sentences`` and return the result as a torch tensor."""
        encoded = self.model.encode(sentences)
        return torch.from_numpy(encoded)

### Utility Functions for Metrics and Timing


In [None]:
@contextmanager
def elapsed_timer():
    """Context manager yielding a callable that reports elapsed seconds.

    The yielded callable returns the time passed since the ``with`` block was
    entered, measured with ``time.perf_counter``.
    """
    started_at = time.perf_counter()

    def elapsed():
        return time.perf_counter() - started_at

    yield elapsed

def classification_metrics(y_true, y_pred, y_prob=None) -> Dict[str, float]:
    """Compute accuracy, macro F1 and (when possible) ROC AUC.

    ``roc_auc`` is NaN when ``y_prob`` is None or when ``roc_auc_score``
    raises for the given inputs.
    """
    accuracy = accuracy_score(y_true, y_pred)
    f1_macro = f1_score(y_true, y_pred, average="macro")

    roc_auc = np.nan
    if y_prob is not None:
        try:
            roc_auc = roc_auc_score(y_true, y_prob)
        except Exception:
            # Keep the NaN placeholder when AUC cannot be computed.
            pass

    return {"accuracy": accuracy, "f1_macro": f1_macro, "roc_auc": roc_auc}

def regression_metrics(y_true, y_pred, y_prob=None) -> Dict[str, float]:
    """Return the mean absolute error and mean squared error of ``y_pred``.

    ``y_prob`` exists only so the signature matches
    ``classification_metrics``; it is never read.
    """
    mae = mean_absolute_error(y_true, y_pred)
    mse = mean_squared_error(y_true, y_pred)
    return dict(mae=mae, mse=mse)

## Helper Functions for Dataset Loading and Table Processing


### Date Feature Engineering and Optional Graph Build

Processes all tables in the dataset to detect and parse date columns, replacing missing values and generating engineered date-related features (e.g., year, month, weekday, cyclical encodings).
Also includes a step to construct a hetero-temporal graph for GNN experiments using text embeddings.


In [None]:
# Load the selected dataset and its relational database. The DOWNLOAD flag is
# forwarded here as well, matching the get_task / get_dataset calls made
# elsewhere in this notebook, so the same source of data is used throughout.
dataset = get_dataset(DATASET, download=DOWNLOAD)
db = dataset.get_db()

def to_pandas(table):
    """Return ``table`` as a pandas object.

    Prefers a ``to_pandas()`` method, then a ``df`` attribute.

    Raises:
        ValueError: if ``table`` offers neither.
    """
    if hasattr(table, "to_pandas"):
        return table.to_pandas()
    try:
        return table.df
    except AttributeError:
        raise ValueError("Unknown table type") from None

# Convert all tables to pandas DataFrames
tables = {name: to_pandas(tbl) for name, tbl in db.table_dict.items()}

# Batch-collect date features and concat at once
for name, df in tables.items():
    # Candidate date columns are selected by name only; the checks inside the
    # loop below weed out false positives of that heuristic.
    date_cols = [col for col in df.columns if "date" in str(col).lower()]
    if not date_cols:
        continue
    feats_list = []
    for col in date_cols:
        # Numeric columns (e.g. "candidate_id", or the engineered features of
        # an earlier run of this cell) would be read as nanoseconds since the
        # epoch by pd.to_datetime, so they are left untouched.
        if pd.api.types.is_numeric_dtype(df[col]):
            continue
        # Already expanded by an earlier run of this cell: do not add the
        # same feature columns a second time.
        if f"{col}_year" in df.columns:
            continue
        dt = pd.to_datetime(df[col], errors="coerce", utc=True)
        # Nothing parseable means the column is not a date column after all;
        # keep its original content instead of overwriting it with NaT.
        if dt.isna().all():
            continue
        # Missing / unparseable entries are replaced by the earliest date.
        dt_filled = dt.fillna(dt.min())
        # Store the parsed column back without timezone information.
        df[col] = dt_filled.dt.tz_localize(None)
        feats = pd.DataFrame({
            f"{col}_year": dt_filled.dt.year,
            f"{col}_month": dt_filled.dt.month,
            f"{col}_day": dt_filled.dt.day,
            f"{col}_weekday": dt_filled.dt.weekday,
            f"{col}_quarter": dt_filled.dt.quarter,
            f"{col}_is_month_start": dt_filled.dt.is_month_start.astype(int),
            f"{col}_is_month_end": dt_filled.dt.is_month_end.astype(int),
            f"{col}_is_weekend": (dt_filled.dt.weekday >= 5).astype(int),
            # Cyclical encodings of month (period 12) and weekday (period 7).
            f"{col}_month_sin": np.sin(2 * np.pi * dt_filled.dt.month / 12),
            f"{col}_month_cos": np.cos(2 * np.pi * dt_filled.dt.month / 12),
            f"{col}_weekday_sin": np.sin(2 * np.pi * dt_filled.dt.weekday / 7),
            f"{col}_weekday_cos": np.cos(2 * np.pi * dt_filled.dt.weekday / 7),
            # Whole days since the earliest date of this column.
            f"{col}_elapsed_days": (dt_filled - dt_filled.min()).dt.days,
        }, index=df.index)
        feats_list.append(feats)
    if feats_list:
        df = pd.concat([df] + feats_list, axis=1)
    tables[name] = df

# --- ADD: push processed tables back into db.table_dict ---
for name, df in tables.items():
    db.table_dict[name].df = df
# --- END ADD ---

# --- ADD THIS BLOCK: Build the hetero-temporal graph for GNN experiments ---
# Column semantic types proposed by RelBench for the tables of `db`.
col_to_stype_dict = get_stype_proposal(db)
# Text embedder configuration: the GloVe sentence-transformer wrapper placed
# on DEVICE, with batch_size=256.
text_embedder_cfg = TextEmbedderConfig(
    text_embedder=GloveTextEmbedding(device=torch.device(DEVICE)), batch_size=256
)
# NOTE(review): `data` and `col_stats_dict` are not read by the TabPFN code
# visible in this notebook; presumably they feed GNN cells elsewhere - confirm.
# NOTE(review): cache_dir=None presumably disables on-disk caching - confirm.
data, col_stats_dict = make_pkey_fkey_graph(
    db,
    col_to_stype_dict=col_to_stype_dict,
    text_embedder_cfg=text_embedder_cfg,
    cache_dir=None,
)

### Fetch Dataset Splits

Utility functions to load a task’s splits (`train`, `val`, `test`), convert them to pandas DataFrames, and extract features (`X`) and targets (`y`). Includes functions to build:
- **Single-table frames** directly from the target entity table.
- **Merged-table frames** by performing a one-hop foreign key → primary key join to denormalize data.


In [None]:
def fetch_splits(dataset_name: str, task_name: str, download: bool = True):
    """Load a task and its train/val/test tables.

    Returns:
        ``(task, splits)`` where ``splits`` maps "train", "val" and "test" to
        the table returned by ``task.get_table`` with ``mask_input_cols=False``.
    """
    task = get_task(dataset_name, task_name, download=download)
    splits = {}
    for split_name in ("train", "val", "test"):
        # mask_input_cols=False keeps the original columns of every split.
        splits[split_name] = task.get_table(split_name, mask_input_cols=False)
    return task, splits

def to_Xy(df: pd.DataFrame, target_col: str):
    """Split ``df`` into features and target.

    Returns:
        ``(X, y)``: ``X`` is ``df`` without ``target_col`` (a new frame, the
        input is not modified) and ``y`` is that column as a NumPy array.
    """
    target_values = df[target_col].to_numpy()
    features = df.drop(columns=[target_col])
    return features, target_values

# --- Naive merged table via one-hop FK→PK joins using dataset.get_db() ---
# Uses only duck-typing on Table objects; falls back gracefully if metadata is missing.
def _infer_pk(table_obj, df: pd.DataFrame):
    """Best-effort lookup of the primary-key column of ``df``.

    Args:
        table_obj: table wrapper that may carry primary-key metadata.
        df: the table's DataFrame.

    Returns:
        The primary-key column label, or None when none can be determined.
    """
    # RelBench tables declare their key as ``pkey_col``. When that attribute
    # exists it is authoritative: a table that declares no key (None) must
    # not get one guessed from column uniqueness.
    if hasattr(table_obj, "pkey_col"):
        declared = table_obj.pkey_col
        if isinstance(declared, str) and declared in df.columns:
            return declared
        return None

    # Other duck-typed table objects: probe common attribute names.
    for attr in ("primary_key_col", "pkey", "pk", "primary_key", "id_col"):
        cand = getattr(table_obj, attr, None)
        if isinstance(cand, str) and cand in df.columns:
            return cand

    # No metadata at all: fall back to the first column with unique values.
    for c in df.columns:
        try:
            if df[c].is_unique:
                return c
        except Exception:
            # Uniqueness cannot be evaluated for this column; try the next.
            continue
    return None


def denormalize_one_hop(dataset, base_df: pd.DataFrame):
    """Left-join one hop of related tables onto ``base_df``.

    Every column of ``base_df`` whose name equals the primary-key column of a
    database table is treated as a foreign key; the non-key columns of that
    table are appended, prefixed with ``"<table name>__"``.

    Args:
        dataset: object whose ``get_db()`` returns the relational database.
        base_df: frame to enrich; it is not modified.

    Returns:
        A new frame with the same rows as ``base_df`` (same order, same
        count) plus the joined columns. ``base_df`` itself is returned when
        the database or its tables cannot be accessed.
    """
    try:
        db = dataset.get_db()  # documented in README
    except Exception:
        return base_df  # cannot access DB; return base table

    # The database keeps its tables in ``table_dict`` (as used by the table
    # preprocessing in this notebook); ``tables`` stays as a fallback for
    # other duck-typed database objects.
    tables = getattr(db, "table_dict", None)
    if not isinstance(tables, dict):
        tables = getattr(db, "tables", None)
    if not isinstance(tables, dict):
        return base_df

    # Build a map: PK column name -> [(table_name, table_df, pk_col), ...]
    pkmap = {}
    for name, tbl in tables.items():
        df = getattr(tbl, "df", None)
        if isinstance(df, pd.DataFrame):
            pk = _infer_pk(tbl, df)
            if pk is not None and pk in df.columns:
                pkmap.setdefault(pk, []).append((name, df, pk))

    df_out = base_df.copy()
    # For each column in base that matches a PK in the DB, left join the non-key attributes
    for fk_col in list(df_out.columns):
        for tname, tdf, pk in pkmap.get(fk_col, ()):
            right = tdf.drop(columns=[pk], errors="ignore")
            if right.empty:
                continue
            # prefix joined columns to avoid collisions
            right = right.add_prefix(f"{tname}__")
            # attach the join key back (unprefixed) for merge
            right[fk_col] = tdf[pk]
            # Keep one row per key so the left join can never multiply base
            # rows; predictions must stay aligned with the split tables.
            right = right.drop_duplicates(subset=[fk_col])
            try:
                df_out = df_out.merge(right, on=fk_col, how="left")
            except Exception as exc:
                # keep going; some keys may be non-joinable due to dtype issues
                print(f"[skip] join {tname} on {fk_col}: {exc!s}")
    return df_out

def build_single_table_frames(task, splits):
    """Build ``(X, y, df)`` per split straight from the task tables.

    ``df`` is a copy of the split's table, ``X`` is that copy without
    ``task.target_col`` and ``y`` is the target column.
    """
    def _frame_for(split_table):
        # IMPORTANT: do not touch column masking/order; just drop the target to form X
        full_df = split_table.df.copy()
        # keep original df for evaluation alignment
        return (*to_Xy(full_df, task.target_col), full_df)

    return {split_name: _frame_for(split_table) for split_name, split_table in splits.items()}

def build_merged_table_frames(dataset, task, splits):
    """Build ``(X, y, merged_df)`` per split from the denormalized tables.

    Each split table is copied, enriched by ``denormalize_one_hop`` and then
    separated into features and target on ``task.target_col``.
    """
    frames = {}
    for split_name, split_table in splits.items():
        enriched = denormalize_one_hop(dataset, split_table.df.copy())
        features, target = to_Xy(enriched, task.target_col)
        frames[split_name] = (features, target, enriched)
    return frames


## Vectorization Wrapper (Version-Safe)

Initializes a `TableVectorizer` with only supported arguments for the installed `skrub` or `dirty_cat` version, ensuring compatibility. Transforms `train`, `val`, and `test` splits into numerical feature matrices, converting them to dense format if necessary.


### Helper Functions for Vectorization

In [None]:
# --- Vectorization wrapper (version-safe for skrub/dirty_cat) ---
# Place this where your previous TableVectorizer/vectorize_splits block was.

def _make_table_vectorizer():
    """Create a ``TableVectorizer`` using only keyword arguments it accepts.

    The constructor signature is inspected, so a parameter missing from the
    installed version is simply not passed. Values come from notebook-level
    globals when those are defined.
    """
    accepted = set(inspect.signature(TableVectorizer.__init__).parameters.keys()) - {"self"}
    notebook_globals = globals()

    # (constructor parameter, global override name, default) - set whenever
    # the installed version knows the parameter.
    with_default = (
        ("cardinality_threshold", "CARDINALITY_THRESHOLD", 1000),
        ("high_cardinality_transformer", "HIGH_CARD_TRANSFORMER", "hashing"),
    )
    # (constructor parameter, global name) - set only when the global exists
    # and the installed version knows the parameter.
    only_if_defined = (
        ("text_separator", "TEXT_SEPARATOR"),
        ("numerical_transformer", "NUMERICAL_TRANSFORMER"),
        ("categorical_transformer", "CATEGORICAL_TRANSFORMER"),
    )

    tv_kwargs = {}
    for param, override_name, default in with_default:
        if param in accepted:
            tv_kwargs[param] = notebook_globals.get(override_name, default)
    for param, global_name in only_if_defined:
        if param in accepted and global_name in notebook_globals:
            tv_kwargs[param] = notebook_globals[global_name]

    return TableVectorizer(**tv_kwargs)

def _to_dense(X):
    """Return ``X.toarray()`` when available, otherwise ``X`` unchanged.

    Any error raised while converting falls back to returning ``X`` itself.
    """
    try:
        # scipy sparse matrices have .toarray()
        if hasattr(X, "toarray"):
            return X.toarray()
    except Exception:
        pass
    return X

def vectorize_splits(X_train, X_val, X_test):
    """Fit a table vectorizer on the training features and transform all splits.

    Returns:
        ``(vectorizer, train, val, test)``; each matrix has been passed
        through ``_to_dense``.
    """
    tv = _make_table_vectorizer()
    train_matrix = _to_dense(tv.fit_transform(X_train))
    # Validation and test only reuse the vectorizer fitted on train.
    val_matrix, test_matrix = (_to_dense(tv.transform(part)) for part in (X_val, X_test))
    return tv, train_matrix, val_matrix, test_matrix


In [None]:
def _subsample(X, y, cap=TABPFN_MAX, seed=SEED):
    """Randomly keep at most ``cap`` rows of ``X`` and ``y``.

    Returns:
        ``(X_kept, y_kept, idx)`` where ``idx`` holds the positions that were
        kept. Inputs with ``len(X) <= cap`` are returned unchanged together
        with ``np.arange(len(X))``.
    """
    n_rows = len(X)
    if n_rows > cap:
        rng = np.random.RandomState(seed)
        picked = rng.choice(n_rows, size=cap, replace=False)
        # Positional selection: .iloc for pandas objects, plain indexing otherwise.
        X_kept = X.iloc[picked] if hasattr(X, "iloc") else X[picked]
        return X_kept, y[picked], picked
    return X, y, np.arange(n_rows)

def _fit_tabpfn(task, Xt, yt):
    """Fit a TabPFN model matching the task type and return it.

    A ``TabPFNRegressor`` is used for regression tasks (when that class is
    available); every other case uses ``TabPFNClassifier``. Both are created
    on ``DEVICE``.
    """
    use_regressor = task.task_type == TaskType.REGRESSION and TabPFNRegressor is not None
    estimator_cls = TabPFNRegressor if use_regressor else TabPFNClassifier
    # NOTE: n_estimators is deliberately not passed (N_ESTIMATORS is unused
    # here), so the library default applies.
    model = estimator_cls(device=DEVICE)
    model.fit(Xt, yt)
    return model

def _predict_for_task(task, model, X):
    """Return predictions in the form used for evaluating ``task``.

    Regression tasks get ``model.predict(X)``. Binary classification gets
    column 1 of ``model.predict_proba(X)``; any other task type gets the full
    probability matrix.
    """
    # align with RelBench evaluators: AUROC expects probabilities for the positive class
    if task.task_type == TaskType.REGRESSION:
        return model.predict(X)

    proba = model.predict_proba(X)
    is_binary = task.task_type == TaskType.BINARY_CLASSIFICATION
    # multiclass/multilabel: pass full probability matrix
    return proba[:, 1] if is_binary else proba

### Run TabPFN on Selected Tasks

Runs TabPFN on a specified dataset and task, handling both single-table and merged-table modes. It vectorizes the data, fits the model, makes predictions, and evaluates performance using RelBench’s evaluators. Returns a dictionary with results.

In [None]:
def run_tabpfn_on_task(dataset_name: str, task_name: str, mode: str = "single") -> Dict[str, Any]:
    """Train and evaluate TabPFN on one task in one table mode.

    Args:
        dataset_name: RelBench dataset name.
        task_name: task name within that dataset.
        mode: "single" (task table only) or "merged" (one-hop denormalized).

    Returns:
        Dict with the dataset/task/mode identifiers, the validation and test
        metrics from ``task.evaluate``, and the row counts of each split.

    Raises:
        ValueError: if ``mode`` is neither "single" nor "merged".
    """
    dataset = get_dataset(dataset_name, download=DOWNLOAD)
    task, splits = fetch_splits(dataset_name, task_name, download=DOWNLOAD)

    if mode not in ("single", "merged"):
        raise ValueError("mode must be 'single' or 'merged'")
    if mode == "single":
        frames = build_single_table_frames(task, splits)
    else:
        frames = build_merged_table_frames(dataset, task, splits)

    raw_train, y_train, _ = frames["train"]
    raw_val, _, _ = frames["val"]
    raw_test, _, _ = frames["test"]

    # Vectorize: the vectorizer is fitted on the full training features.
    _, train_matrix, val_matrix, test_matrix = vectorize_splits(raw_train, raw_val, raw_test)

    # Respect TabPFN's sample cap
    fit_matrix, fit_target, _ = _subsample(train_matrix, y_train, cap=TABPFN_MAX, seed=SEED)

    model = _fit_tabpfn(task, fit_matrix, fit_target)

    # Predict & Evaluate with RelBench evaluators
    val_pred = _predict_for_task(task, model, val_matrix)
    test_pred = _predict_for_task(task, model, test_matrix)
    val_metrics = task.evaluate(val_pred, splits["val"])   # documented API
    test_metrics = task.evaluate(test_pred, splits["test"])  # documented API

    return {
        "dataset": dataset_name,
        "task": task_name,
        "mode": mode,
        "val_metrics": val_metrics,
        "test_metrics": test_metrics,
        "n_train_used": len(fit_matrix),
        "n_train_total": len(train_matrix),
        "n_val": len(val_matrix),
        "n_test": len(test_matrix),
    }

## Orchestrator for Benchmark Runs

Iterates over all discovered tasks and runs TabPFN in both **single** and **merged** modes. Collects performance metrics for validation and test splits into a results table, handling failures gracefully. Sorts results for easier comparison.


In [None]:
# Table modes to benchmark; a MODES value already defined in the notebook
# namespace takes precedence over the default pair.
MODES = globals().get("MODES", ["single", "merged"])

# records: one dict per (task, mode, split, metric) that produced a score.
# failures: one message per (task, mode) run that raised.
records = []
failures = []

for task_name in TASKS:
    for mode in MODES:
        try:
            res = run_tabpfn_on_task(DATASET, task_name, mode=mode)
            # Flatten metrics for val and test, one metric per row
            for split in ["val", "test"]:
                # A missing or empty metrics entry is treated as "no metrics".
                metrics = res.get(f"{split}_metrics") or {}
                for metric_name, metric_value in metrics.items():
                    # Only add rows with non-empty metric_value
                    # (None and float NaN scores are skipped).
                    if metric_value is not None and not (isinstance(metric_value, float) and np.isnan(metric_value)):
                        records.append({
                            "dataset": res.get("dataset", DATASET),
                            "task": res.get("task", task_name),
                            "mode": res.get("mode", mode),
                            "method": "TabPFN_experimental_v1.0",
                            "metric": f"{split}_{metric_name}",
                            "score": metric_value,
                        })
        except Exception as e:
            # A failing run must not stop the benchmark: report it, remember
            # the message and continue with the next (task, mode) pair.
            msg = f"[{DATASET} | {task_name} | {mode}] failed: {e!s}"
            print(msg)
            failures.append(msg)

# Long-format results: one row per recorded metric.
results_df = pd.DataFrame.from_records(records)

if results_df.empty:
    print("No successful runs were recorded. Check the failure messages above.")
    # Create an empty, well-formed frame so downstream plotting code doesn't crash
    results_df = pd.DataFrame(
        columns=["dataset", "task", "mode", "method", "metric", "score"]
    )
else:
    # Ensure required sort keys exist even if some rows missed them
    for col in ["task", "mode"]:
        if col not in results_df.columns:
            results_df[col] = pd.NA
    # Sort only by the columns that exist to avoid KeyError
    sort_keys = [c for c in ["task", "mode", "metric"] if c in results_df.columns]
    if sort_keys:
        results_df = results_df.sort_values(sort_keys)

# `display` is not imported in this file; it is presumably the IPython /
# Jupyter built-in available inside a notebook session.
display(results_df)


In [None]:
# --- Persist results to CSV ---

# Output directory; an OUT_DIR value defined in the notebook namespace wins.
out_dir = globals().get("OUT_DIR", "outputs")
os.makedirs(out_dir, exist_ok=True)

# Timestamp formatted as "dd.mm.yyyy-hh.mm". The hour/minute separator is a
# dot rather than ":" because a colon is not allowed in Windows file names,
# which made the export fail or produce a truncated file name there.
timestamp = time.strftime("%d.%m.%Y-%H.%M")
csv_name = f"tabpfn_{DATASET}_{timestamp}.csv"
csv_path = os.path.join(out_dir, csv_name)

results_df.to_csv(csv_path, index=False)
print(f"Saved results to: {csv_path}")
